1 /*
2  * Copyright (c) 2020 Andri Yngvason
3  *
4  * Permission to use, copy, modify, and/or distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  *
8  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
9  * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
10  * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
11  * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
12  * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
13  * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
14  * PERFORMANCE OF THIS SOFTWARE.
15  */
16 
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <unistd.h>
20 #include <string.h>
21 #include <assert.h>
22 #include <errno.h>
23 #include <sys/uio.h>
24 #include <limits.h>
25 #include <aml.h>
26 #include <fcntl.h>
27 #include <poll.h>
28 
29 #ifdef ENABLE_TLS
30 #include <gnutls/gnutls.h>
31 #endif
32 
33 #include "rcbuf.h"
34 #include "stream.h"
35 #include "sys/queue.h"
36 
37 static void stream__on_event(void* obj);
38 #ifdef ENABLE_TLS
39 static int stream__try_tls_accept(struct stream* self);
40 #endif
41 
stream__poll_r(struct stream * self)42 static inline void stream__poll_r(struct stream* self)
43 {
44 	aml_set_event_mask(self->handler, AML_EVENT_READ);
45 }
46 
stream__poll_w(struct stream * self)47 static inline void stream__poll_w(struct stream* self)
48 {
49 	aml_set_event_mask(self->handler, AML_EVENT_WRITE);
50 }
51 
stream__poll_rw(struct stream * self)52 static inline void stream__poll_rw(struct stream* self)
53 {
54 	aml_set_event_mask(self->handler, AML_EVENT_READ | AML_EVENT_WRITE);
55 }
56 
stream_req__finish(struct stream_req * req,enum stream_req_status status)57 static void stream_req__finish(struct stream_req* req,
58                                enum stream_req_status status)
59 {
60 	if (req->on_done)
61 		req->on_done(req->userdata, status);
62 
63 	rcbuf_unref(req->payload);
64 	free(req);
65 }
66 
stream_close(struct stream * self)67 int stream_close(struct stream* self)
68 {
69 	if (self->state == STREAM_STATE_CLOSED)
70 		return -1;
71 
72 	self->state = STREAM_STATE_CLOSED;
73 
74 	while (!TAILQ_EMPTY(&self->send_queue)) {
75 		struct stream_req* req = TAILQ_FIRST(&self->send_queue);
76 		TAILQ_REMOVE(&self->send_queue, req, link);
77 		stream_req__finish(req, STREAM_REQ_FAILED);
78 	}
79 
80 #ifdef ENABLE_TLS
81 	if (self->tls_session)
82 		gnutls_deinit(self->tls_session);
83 	self->tls_session = NULL;
84 #endif
85 
86 	// TODO: Maybe use explicit loop object instead of the default one?
87 	aml_stop(aml_get_default(), self->handler);
88 	close(self->fd);
89 	self->fd = -1;
90 
91 	return 0;
92 }
93 
stream_destroy(struct stream * self)94 void stream_destroy(struct stream* self)
95 {
96 	stream_close(self);
97 	aml_unref(self->handler);
98 }
99 
stream__remote_closed(struct stream * self)100 static void stream__remote_closed(struct stream* self)
101 {
102 	stream_close(self);
103 
104 	if (self->on_event)
105 		self->on_event(self, STREAM_EVENT_REMOTE_CLOSED);
106 }
107 
stream__flush_plain(struct stream * self)108 static int stream__flush_plain(struct stream* self)
109 {
110 	static struct iovec iov[IOV_MAX];
111 	size_t n_msgs = 0;
112 	ssize_t bytes_sent;
113 
114 	struct stream_req* req;
115 	TAILQ_FOREACH(req, &self->send_queue, link) {
116 		iov[n_msgs].iov_base = req->payload->payload;
117 		iov[n_msgs].iov_len = req->payload->size;
118 
119 		if (++n_msgs >= IOV_MAX)
120 			break;
121 	}
122 
123 	if (n_msgs == 0)
124 		return 0;
125 
126 	bytes_sent = writev(self->fd, iov, n_msgs);
127 	if (bytes_sent < 0) {
128 		if (errno == EAGAIN || errno == EWOULDBLOCK) {
129 			stream__poll_rw(self);
130 			errno = EAGAIN;
131 		} else if (errno == EPIPE) {
132 			stream__remote_closed(self);
133 			errno = EPIPE;
134 		}
135 
136 		return bytes_sent;
137 	}
138 
139 	self->bytes_sent += bytes_sent;
140 
141 	ssize_t bytes_left = bytes_sent;
142 
143 	struct stream_req* tmp;
144 	TAILQ_FOREACH_SAFE(req, &self->send_queue, link, tmp) {
145 		bytes_left -= req->payload->size;
146 
147 		if (bytes_left >= 0) {
148 			TAILQ_REMOVE(&self->send_queue, req, link);
149 			stream_req__finish(req, STREAM_REQ_DONE);
150 		} else {
151 			char* p = req->payload->payload;
152 			size_t s = req->payload->size;
153 			memmove(p, p + s + bytes_left, -bytes_left);
154 			req->payload->size = -bytes_left;
155 			stream__poll_rw(self);
156 		}
157 
158 		if (bytes_left <= 0)
159 			break;
160 	}
161 
162 	if (bytes_left == 0 && self->state != STREAM_STATE_CLOSED)
163 		stream__poll_r(self);
164 
165 	assert(bytes_left <= 0);
166 
167 	return bytes_sent;
168 }
169 
170 #ifdef ENABLE_TLS
stream__flush_tls(struct stream * self)171 static int stream__flush_tls(struct stream* self)
172 {
173 	while (!TAILQ_EMPTY(&self->send_queue)) {
174 		struct stream_req* req = TAILQ_FIRST(&self->send_queue);
175 
176 		ssize_t rc = gnutls_record_send(
177 			self->tls_session, req->payload->payload,
178 			req->payload->size);
179 		if (rc < 0) {
180 			gnutls_record_discard_queued(self->tls_session);
181 			if (gnutls_error_is_fatal(rc))
182 				stream_close(self);
183 			return -1;
184 		}
185 
186 		self->bytes_sent += rc;
187 
188 		ssize_t remaining = req->payload->size - rc;
189 
190 		if (remaining > 0) {
191 			char* p = req->payload->payload;
192 			size_t s = req->payload->size;
193 			memmove(p, p + s - remaining, remaining);
194 			req->payload->size = remaining;
195 			stream__poll_rw(self);
196 			return 1;
197 		}
198 
199 		assert(remaining == 0);
200 
201 		TAILQ_REMOVE(&self->send_queue, req, link);
202 		stream_req__finish(req, STREAM_REQ_DONE);
203 	}
204 
205 	if (TAILQ_EMPTY(&self->send_queue) && self->state != STREAM_STATE_CLOSED)
206 		stream__poll_r(self);
207 
208 	return 1;
209 }
210 #endif
211 
stream__flush(struct stream * self)212 static int stream__flush(struct stream* self)
213 {
214 	switch (self->state) {
215 	case STREAM_STATE_NORMAL: return stream__flush_plain(self);
216 #ifdef ENABLE_TLS
217 	case STREAM_STATE_TLS_READY: return stream__flush_tls(self);
218 #endif
219 	default:
220 		break;
221 	}
222 	abort();
223 	return -1;
224 }
225 
stream__on_readable(struct stream * self)226 static void stream__on_readable(struct stream* self)
227 {
228 	switch (self->state) {
229 	case STREAM_STATE_NORMAL:
230 		/* fallthrough */
231 #ifdef ENABLE_TLS
232 	case STREAM_STATE_TLS_READY:
233 #endif
234 		if (self->on_event)
235 			self->on_event(self, STREAM_EVENT_READ);
236 		break;
237 #ifdef ENABLE_TLS
238 	case STREAM_STATE_TLS_HANDSHAKE:
239 		stream__try_tls_accept(self);
240 		break;
241 #endif
242 	case STREAM_STATE_CLOSED:
243 		break;
244 	}
245 }
246 
stream__on_writable(struct stream * self)247 static void stream__on_writable(struct stream* self)
248 {
249 	switch (self->state) {
250 	case STREAM_STATE_NORMAL:
251 		/* fallthrough */
252 #ifdef ENABLE_TLS
253 	case STREAM_STATE_TLS_READY:
254 #endif
255 		stream__flush(self);
256 		break;
257 #ifdef ENABLE_TLS
258 	case STREAM_STATE_TLS_HANDSHAKE:
259 		stream__try_tls_accept(self);
260 		break;
261 #endif
262 	case STREAM_STATE_CLOSED:
263 		break;
264 	}
265 }
266 
stream__on_event(void * obj)267 static void stream__on_event(void* obj)
268 {
269 	struct stream* self = aml_get_userdata(obj);
270 	uint32_t events = aml_get_revents(obj);
271 
272 	if (events & AML_EVENT_READ)
273 		stream__on_readable(self);
274 
275 	if (events & AML_EVENT_WRITE)
276 		stream__on_writable(self);
277 }
278 
/*
 * Create a stream on top of an already open file descriptor.
 *
 * The descriptor is switched to non-blocking mode, wrapped in an aml handler
 * that is started on the default aml loop, and initially polled for
 * readability only. `on_event` and `userdata` are stored on the stream;
 * `on_event` is invoked for stream events such as readable data or a remote
 * close.
 *
 * Returns the new stream, or NULL if allocation, handler creation or
 * starting the handler failed.
 */
struct stream* stream_new(int fd, stream_event_fn on_event, void* userdata)
{
	struct stream* self = calloc(1, sizeof(*self));
	if (!self)
		return NULL;

	self->fd = fd;
	self->on_event = on_event;
	self->userdata = userdata;

	TAILQ_INIT(&self->send_queue);

	int fd_flags = fcntl(fd, F_GETFL, 0);
	fcntl(fd, F_SETFL, fd_flags | O_NONBLOCK);

	/* free() is registered as the cleanup function for `self`. */
	self->handler = aml_handler_new(fd, stream__on_event, self, free);
	if (!self->handler) {
		free(self);
		return NULL;
	}

	if (aml_start(aml_get_default(), self->handler) < 0) {
		/* Releasing the handler also takes care of `self` (via the
		 * cleanup function above), so it must not be freed here. */
		aml_unref(self->handler);
		return NULL;
	}

	stream__poll_r(self);

	return self;
}
311 
/*
 * Queue `payload` for sending and try to flush the queue immediately.
 *
 * Ownership: the caller's reference to `payload` is always consumed by this
 * call. On success it is released when the request completes; on an early
 * failure (stream closed, out of memory) it is released before returning.
 * A -1 return cannot tell the caller whether the request was ever queued,
 * so the caller must never unref `payload` itself after calling this.
 *
 * `on_done` (may be NULL) is called with `userdata` when a queued request
 * finishes. It is NOT called when the request is rejected before being
 * queued, i.e. when the stream is already closed or allocation fails.
 *
 * Returns -1 if the request was rejected; otherwise the result of the flush
 * attempt, which may itself be negative (e.g. when the socket would block).
 */
int stream_send(struct stream* self, struct rcbuf* payload,
                stream_req_fn on_done, void* userdata)
{
	if (self->state == STREAM_STATE_CLOSED) {
		/* Don't leak the reference the caller handed over. */
		rcbuf_unref(payload);
		return -1;
	}

	struct stream_req* req = calloc(1, sizeof(*req));
	if (!req) {
		rcbuf_unref(payload);
		return -1;
	}

	req->payload = payload;
	req->on_done = on_done;
	req->userdata = userdata;

	TAILQ_INSERT_TAIL(&self->send_queue, req, link);

	return stream__flush(self);
}
330 
/*
 * Convenience wrapper around stream_send() for raw memory.
 *
 * Wraps `len` bytes at `payload` in a new rcbuf created by rcbuf_from_mem()
 * (presumably a copy, given the const source pointer; confirm in rcbuf.c)
 * and queues it for sending. `on_done`/`userdata` are passed through to
 * stream_send().
 *
 * Returns -1 if the stream is closed or the buffer could not be created;
 * otherwise whatever stream_send() returns.
 */
int stream_write(struct stream* self, const void* payload, size_t len,
                 stream_req_fn on_done, void* userdata)
{
	/* Reject up front: stream_send() would refuse a closed stream anyway,
	 * and checking here avoids creating a buffer that nothing will ever
	 * release. */
	if (self->state == STREAM_STATE_CLOSED)
		return -1;

	struct rcbuf* buf = rcbuf_from_mem(payload, len);
	if (!buf)
		return -1;

	return stream_send(self, buf, on_done, userdata);
}
337 
stream__read_plain(struct stream * self,void * dst,size_t size)338 static ssize_t stream__read_plain(struct stream* self, void* dst, size_t size)
339 {
340 	ssize_t rc = read(self->fd, dst, size);
341 	if (rc == 0)
342 		stream__remote_closed(self);
343 	if (rc > 0)
344 		self->bytes_received += rc;
345 	return rc;
346 }
347 
348 #ifdef ENABLE_TLS
stream__read_tls(struct stream * self,void * dst,size_t size)349 static ssize_t stream__read_tls(struct stream* self, void* dst, size_t size)
350 {
351 	ssize_t rc = gnutls_record_recv(self->tls_session, dst, size);
352 	if (rc >= 0) {
353 		self->bytes_received += rc;
354 		return rc;
355 	}
356 
357 	switch (rc) {
358 	case GNUTLS_E_INTERRUPTED:
359 		errno = EINTR;
360 		break;
361 	case GNUTLS_E_AGAIN:
362 		errno = EAGAIN;
363 		break;
364 	default:
365 		errno = 0;
366 		break;
367 	}
368 
369 	// Make sure data wasn't being written.
370 	assert(gnutls_record_get_direction(self->tls_session) == 0);
371 	return -1;
372 }
373 #endif
374 
stream_read(struct stream * self,void * dst,size_t size)375 ssize_t stream_read(struct stream* self, void* dst, size_t size)
376 {
377 	switch (self->state) {
378 	case STREAM_STATE_NORMAL: return stream__read_plain(self, dst, size);
379 #ifdef ENABLE_TLS
380 	case STREAM_STATE_TLS_READY: return stream__read_tls(self, dst, size);
381 #endif
382 	default: break;
383 	}
384 
385 	abort();
386 	return -1;
387 }
388 
389 #ifdef ENABLE_TLS
stream__try_tls_accept(struct stream * self)390 static int stream__try_tls_accept(struct stream* self)
391 {
392 	int rc;
393 
394 	rc = gnutls_handshake(self->tls_session);
395 	if (rc == GNUTLS_E_SUCCESS) {
396 		self->state = STREAM_STATE_TLS_READY;
397 		stream__poll_r(self);
398 		return 0;
399 	}
400 
401 	if (gnutls_error_is_fatal(rc)) {
402 		aml_stop(aml_get_default(), &self->handler);
403 		return -1;
404 	}
405 
406 	int was_writing = gnutls_record_get_direction(self->tls_session);
407 	if (was_writing)
408 		stream__poll_w(self);
409 	else
410 		stream__poll_r(self);
411 
412 	self->state = STREAM_STATE_TLS_HANDSHAKE;
413 	return 0;
414 }
415 
stream_upgrade_to_tls(struct stream * self,void * context)416 int stream_upgrade_to_tls(struct stream* self, void* context)
417 {
418 	int rc;
419 
420 	rc = gnutls_init(&self->tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
421 	if (rc != GNUTLS_E_SUCCESS)
422 		return -1;
423 
424 	rc = gnutls_set_default_priority(self->tls_session);
425 	if (rc != GNUTLS_E_SUCCESS)
426 		goto failure;
427 
428 	rc = gnutls_credentials_set(self->tls_session, GNUTLS_CRD_CERTIFICATE,
429 	                            context);
430 	if (rc != GNUTLS_E_SUCCESS)
431 		goto failure;
432 
433 	gnutls_transport_set_int(self->tls_session, self->fd);
434 
435 	return stream__try_tls_accept(self);
436 
437 failure:
438 	gnutls_deinit(self->tls_session);
439 	return -1;
440 }
441 #endif
442