/*
 * Copyright (c) 2020 Andri Yngvason
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
 * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
 * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
 * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 * PERFORMANCE OF THIS SOFTWARE.
 */
16
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <unistd.h>
20 #include <string.h>
21 #include <assert.h>
22 #include <errno.h>
23 #include <sys/uio.h>
24 #include <limits.h>
25 #include <aml.h>
26 #include <fcntl.h>
27 #include <poll.h>
28
29 #ifdef ENABLE_TLS
30 #include <gnutls/gnutls.h>
31 #endif
32
33 #include "rcbuf.h"
34 #include "stream.h"
35 #include "sys/queue.h"
36
37 static void stream__on_event(void* obj);
38 #ifdef ENABLE_TLS
39 static int stream__try_tls_accept(struct stream* self);
40 #endif
41
stream__poll_r(struct stream * self)42 static inline void stream__poll_r(struct stream* self)
43 {
44 aml_set_event_mask(self->handler, AML_EVENT_READ);
45 }
46
stream__poll_w(struct stream * self)47 static inline void stream__poll_w(struct stream* self)
48 {
49 aml_set_event_mask(self->handler, AML_EVENT_WRITE);
50 }
51
stream__poll_rw(struct stream * self)52 static inline void stream__poll_rw(struct stream* self)
53 {
54 aml_set_event_mask(self->handler, AML_EVENT_READ | AML_EVENT_WRITE);
55 }
56
stream_req__finish(struct stream_req * req,enum stream_req_status status)57 static void stream_req__finish(struct stream_req* req,
58 enum stream_req_status status)
59 {
60 if (req->on_done)
61 req->on_done(req->userdata, status);
62
63 rcbuf_unref(req->payload);
64 free(req);
65 }
66
stream_close(struct stream * self)67 int stream_close(struct stream* self)
68 {
69 if (self->state == STREAM_STATE_CLOSED)
70 return -1;
71
72 self->state = STREAM_STATE_CLOSED;
73
74 while (!TAILQ_EMPTY(&self->send_queue)) {
75 struct stream_req* req = TAILQ_FIRST(&self->send_queue);
76 TAILQ_REMOVE(&self->send_queue, req, link);
77 stream_req__finish(req, STREAM_REQ_FAILED);
78 }
79
80 #ifdef ENABLE_TLS
81 if (self->tls_session)
82 gnutls_deinit(self->tls_session);
83 self->tls_session = NULL;
84 #endif
85
86 // TODO: Maybe use explicit loop object instead of the default one?
87 aml_stop(aml_get_default(), self->handler);
88 close(self->fd);
89 self->fd = -1;
90
91 return 0;
92 }
93
stream_destroy(struct stream * self)94 void stream_destroy(struct stream* self)
95 {
96 stream_close(self);
97 aml_unref(self->handler);
98 }
99
stream__remote_closed(struct stream * self)100 static void stream__remote_closed(struct stream* self)
101 {
102 stream_close(self);
103
104 if (self->on_event)
105 self->on_event(self, STREAM_EVENT_REMOTE_CLOSED);
106 }
107
/* Flush as much of the send queue as possible with a single writev().
 *
 * Returns the number of bytes written, 0 if the queue was empty, or -1
 * with errno set (EAGAIN when the socket buffer is full, EPIPE when the
 * peer has closed the connection).
 */
static int stream__flush_plain(struct stream* self)
{
	// NOTE(review): static buffer makes this function non-reentrant;
	// presumably fine under a single-threaded event loop — confirm no
	// concurrent callers exist.
	static struct iovec iov[IOV_MAX];
	size_t n_msgs = 0;
	ssize_t bytes_sent;

	// Gather up to IOV_MAX queued payloads into the iovec array.
	struct stream_req* req;
	TAILQ_FOREACH(req, &self->send_queue, link) {
		iov[n_msgs].iov_base = req->payload->payload;
		iov[n_msgs].iov_len = req->payload->size;

		if (++n_msgs >= IOV_MAX)
			break;
	}

	if (n_msgs == 0)
		return 0;

	bytes_sent = writev(self->fd, iov, n_msgs);
	if (bytes_sent < 0) {
		if (errno == EAGAIN || errno == EWOULDBLOCK) {
			// Socket full: also poll for writability so the queue
			// gets flushed as soon as there is room again.
			stream__poll_rw(self);
			errno = EAGAIN;
		} else if (errno == EPIPE) {
			stream__remote_closed(self);
			errno = EPIPE;
		}

		return bytes_sent;
	}

	self->bytes_sent += bytes_sent;

	// Walk the queue completing fully-sent requests; bytes_left goes
	// negative on the first partially-sent request.
	ssize_t bytes_left = bytes_sent;

	struct stream_req* tmp;
	TAILQ_FOREACH_SAFE(req, &self->send_queue, link, tmp) {
		bytes_left -= req->payload->size;

		if (bytes_left >= 0) {
			TAILQ_REMOVE(&self->send_queue, req, link);
			stream_req__finish(req, STREAM_REQ_DONE);
		} else {
			// Partially sent: shift the unsent tail (-bytes_left
			// bytes) to the front of the buffer and retry when the
			// fd becomes writable again.
			char* p = req->payload->payload;
			size_t s = req->payload->size;
			memmove(p, p + s + bytes_left, -bytes_left);
			req->payload->size = -bytes_left;
			stream__poll_rw(self);
		}

		if (bytes_left <= 0)
			break;
	}

	// Everything flushed: go back to read-only polling.
	if (bytes_left == 0 && self->state != STREAM_STATE_CLOSED)
		stream__poll_r(self);

	assert(bytes_left <= 0);

	return bytes_sent;
}
169
170 #ifdef ENABLE_TLS
/* Flush the send queue over TLS, one gnutls record per queued request.
 *
 * Returns 1 on success (possibly with data still pending after a partial
 * record), -1 on error.  On a fatal TLS error the stream is closed.
 */
static int stream__flush_tls(struct stream* self)
{
	while (!TAILQ_EMPTY(&self->send_queue)) {
		struct stream_req* req = TAILQ_FIRST(&self->send_queue);

		ssize_t rc = gnutls_record_send(
				self->tls_session, req->payload->payload,
				req->payload->size);
		if (rc < 0) {
			// Drop any record gnutls queued internally so stale
			// data is not resent on the next attempt.
			gnutls_record_discard_queued(self->tls_session);
			if (gnutls_error_is_fatal(rc))
				stream_close(self);
			return -1;
		}

		self->bytes_sent += rc;

		ssize_t remaining = req->payload->size - rc;

		if (remaining > 0) {
			// Partial send: move the unsent tail to the front of
			// the buffer and wait for writability to continue.
			char* p = req->payload->payload;
			size_t s = req->payload->size;
			memmove(p, p + s - remaining, remaining);
			req->payload->size = remaining;
			stream__poll_rw(self);
			return 1;
		}

		assert(remaining == 0);

		TAILQ_REMOVE(&self->send_queue, req, link);
		stream_req__finish(req, STREAM_REQ_DONE);
	}

	// Queue drained: go back to read-only polling.
	if (TAILQ_EMPTY(&self->send_queue) && self->state != STREAM_STATE_CLOSED)
		stream__poll_r(self);

	return 1;
}
210 #endif
211
stream__flush(struct stream * self)212 static int stream__flush(struct stream* self)
213 {
214 switch (self->state) {
215 case STREAM_STATE_NORMAL: return stream__flush_plain(self);
216 #ifdef ENABLE_TLS
217 case STREAM_STATE_TLS_READY: return stream__flush_tls(self);
218 #endif
219 default:
220 break;
221 }
222 abort();
223 return -1;
224 }
225
stream__on_readable(struct stream * self)226 static void stream__on_readable(struct stream* self)
227 {
228 switch (self->state) {
229 case STREAM_STATE_NORMAL:
230 /* fallthrough */
231 #ifdef ENABLE_TLS
232 case STREAM_STATE_TLS_READY:
233 #endif
234 if (self->on_event)
235 self->on_event(self, STREAM_EVENT_READ);
236 break;
237 #ifdef ENABLE_TLS
238 case STREAM_STATE_TLS_HANDSHAKE:
239 stream__try_tls_accept(self);
240 break;
241 #endif
242 case STREAM_STATE_CLOSED:
243 break;
244 }
245 }
246
stream__on_writable(struct stream * self)247 static void stream__on_writable(struct stream* self)
248 {
249 switch (self->state) {
250 case STREAM_STATE_NORMAL:
251 /* fallthrough */
252 #ifdef ENABLE_TLS
253 case STREAM_STATE_TLS_READY:
254 #endif
255 stream__flush(self);
256 break;
257 #ifdef ENABLE_TLS
258 case STREAM_STATE_TLS_HANDSHAKE:
259 stream__try_tls_accept(self);
260 break;
261 #endif
262 case STREAM_STATE_CLOSED:
263 break;
264 }
265 }
266
stream__on_event(void * obj)267 static void stream__on_event(void* obj)
268 {
269 struct stream* self = aml_get_userdata(obj);
270 uint32_t events = aml_get_revents(obj);
271
272 if (events & AML_EVENT_READ)
273 stream__on_readable(self);
274
275 if (events & AML_EVENT_WRITE)
276 stream__on_writable(self);
277 }
278
/* Create a stream around an already-open fd and register it with the
 * default event loop.  The fd is switched to non-blocking mode.
 *
 * Ownership note: free() is registered as the aml handler's cleanup
 * routine with self as userdata, so the final aml_unref() on the handler
 * releases the stream object itself (see stream_destroy()).
 *
 * Returns the new stream, or NULL on allocation/registration failure.
 */
struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata)
{
	struct stream* self = calloc(1, sizeof(*self));
	if (!self)
		return NULL;

	self->fd = fd;
	self->on_event = on_event;
	self->userdata = userdata;

	TAILQ_INIT(&self->send_queue);

	// NOTE(review): fcntl() return values are unchecked — presumably fd
	// is always a valid socket here; confirm against callers.
	fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);

	// From here on the handler owns self (free is its destructor).
	self->handler = aml_handler_new(fd, stream__on_event, self, free);
	if (!self->handler)
		goto failure;

	if (aml_start(aml_get_default(), self->handler) < 0)
		goto start_failure;

	stream__poll_r(self);

	return self;

start_failure:
	aml_unref(self->handler);
	self = NULL; /* Handled in unref */
failure:
	free(self); /* no-op when self was already released via unref */
	return NULL;
}
311
/* Queue a payload for transmission and attempt an immediate flush.
 *
 * Returns the flush result, or -1 if the stream is closed or the request
 * object cannot be allocated.
 * NOTE(review): payload is not unreffed on the failure paths here —
 * confirm whether callers retain ownership in that case.
 */
int stream_send(struct stream* self, struct rcbuf* payload,
		stream_req_fn on_done, void* userdata)
{
	struct stream_req* req;

	if (self->state == STREAM_STATE_CLOSED)
		return -1;

	req = calloc(1, sizeof *req);
	if (!req)
		return -1;

	req->userdata = userdata;
	req->on_done = on_done;
	req->payload = payload;

	TAILQ_INSERT_TAIL(&self->send_queue, req, link);

	return stream__flush(self);
}
330
/* Convenience wrapper: copy len bytes of payload into a new rcbuf and
 * queue it for sending.  Returns -1 if the buffer cannot be created. */
int stream_write(struct stream* self, const void* payload, size_t len,
		stream_req_fn on_done, void* userdata)
{
	struct rcbuf* buf = rcbuf_from_mem(payload, len);
	if (!buf)
		return -1;

	return stream_send(self, buf, on_done, userdata);
}
337
stream__read_plain(struct stream * self,void * dst,size_t size)338 static ssize_t stream__read_plain(struct stream* self, void* dst, size_t size)
339 {
340 ssize_t rc = read(self->fd, dst, size);
341 if (rc == 0)
342 stream__remote_closed(self);
343 if (rc > 0)
344 self->bytes_received += rc;
345 return rc;
346 }
347
348 #ifdef ENABLE_TLS
stream__read_tls(struct stream * self,void * dst,size_t size)349 static ssize_t stream__read_tls(struct stream* self, void* dst, size_t size)
350 {
351 ssize_t rc = gnutls_record_recv(self->tls_session, dst, size);
352 if (rc >= 0) {
353 self->bytes_received += rc;
354 return rc;
355 }
356
357 switch (rc) {
358 case GNUTLS_E_INTERRUPTED:
359 errno = EINTR;
360 break;
361 case GNUTLS_E_AGAIN:
362 errno = EAGAIN;
363 break;
364 default:
365 errno = 0;
366 break;
367 }
368
369 // Make sure data wasn't being written.
370 assert(gnutls_record_get_direction(self->tls_session) == 0);
371 return -1;
372 }
373 #endif
374
stream_read(struct stream * self,void * dst,size_t size)375 ssize_t stream_read(struct stream* self, void* dst, size_t size)
376 {
377 switch (self->state) {
378 case STREAM_STATE_NORMAL: return stream__read_plain(self, dst, size);
379 #ifdef ENABLE_TLS
380 case STREAM_STATE_TLS_READY: return stream__read_tls(self, dst, size);
381 #endif
382 default: break;
383 }
384
385 abort();
386 return -1;
387 }
388
389 #ifdef ENABLE_TLS
stream__try_tls_accept(struct stream * self)390 static int stream__try_tls_accept(struct stream* self)
391 {
392 int rc;
393
394 rc = gnutls_handshake(self->tls_session);
395 if (rc == GNUTLS_E_SUCCESS) {
396 self->state = STREAM_STATE_TLS_READY;
397 stream__poll_r(self);
398 return 0;
399 }
400
401 if (gnutls_error_is_fatal(rc)) {
402 aml_stop(aml_get_default(), &self->handler);
403 return -1;
404 }
405
406 int was_writing = gnutls_record_get_direction(self->tls_session);
407 if (was_writing)
408 stream__poll_w(self);
409 else
410 stream__poll_r(self);
411
412 self->state = STREAM_STATE_TLS_HANDSHAKE;
413 return 0;
414 }
415
stream_upgrade_to_tls(struct stream * self,void * context)416 int stream_upgrade_to_tls(struct stream* self, void* context)
417 {
418 int rc;
419
420 rc = gnutls_init(&self->tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
421 if (rc != GNUTLS_E_SUCCESS)
422 return -1;
423
424 rc = gnutls_set_default_priority(self->tls_session);
425 if (rc != GNUTLS_E_SUCCESS)
426 goto failure;
427
428 rc = gnutls_credentials_set(self->tls_session, GNUTLS_CRD_CERTIFICATE,
429 context);
430 if (rc != GNUTLS_E_SUCCESS)
431 goto failure;
432
433 gnutls_transport_set_int(self->tls_session, self->fd);
434
435 return stream__try_tls_accept(self);
436
437 failure:
438 gnutls_deinit(self->tls_session);
439 return -1;
440 }
441 #endif
442