1 #include "sha1/sha1.h"
2 #include <b64/cencode.h>
3 #include "websocket.h"
4 #include "client.h"
5 #include "cmd.h"
6 #include "worker.h"
7 #include "pool.h"
8 #include "http.h"
9 #include "slog.h"
10 #include "server.h"
11 #include "conf.h"
12
13 /* message parsers */
14 #include "formats/json.h"
15 #include "formats/raw.h"
16
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <string.h>
20 #include <unistd.h>
21 #include <errno.h>
22 #include <sys/param.h>
23
/* Forward declaration: ws_handshake_reply() and ws_frame_and_send_response()
 * queue output through this helper, which is defined at the end of the file. */
static int ws_schedule_write(struct ws_client *ws);
26
/**
 * This code uses the WebSocket specification from RFC 6455.
 * A copy is available at http://www.rfc-editor.org/rfc/rfc6455.txt
 */

/*
 * 64-bit host <-> network byte order conversion, used for the 8-byte
 * extended payload length of WebSocket frames.
 *
 * Big-endian hosts are recognised either by __BIG_ENDIAN__ (defined by some
 * compilers/targets) or by the GCC/clang __BYTE_ORDER__ predefines, so
 * big-endian Linux targets do not fall through to the byte-swapping variant.
 *
 * The argument is widened to uint64_t before shifting: shifting a 32-bit
 * value (e.g. size_t on 32-bit platforms) by 32 is undefined behaviour.
 * Note: the argument is evaluated twice in the little-endian variant; pass
 * side-effect-free expressions only.
 */
#if (defined(__BIG_ENDIAN__) && __BIG_ENDIAN__) || \
    (defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
     __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
# define webdis_htonll(x) ((uint64_t)(x))
# define webdis_ntohll(x) ((uint64_t)(x))
#else
# define webdis_htonll(x) \
    (((uint64_t)htonl((uint32_t)((uint64_t)(x) & 0xFFFFFFFFu)) << 32) | \
     (uint64_t)htonl((uint32_t)((uint64_t)(x) >> 32)))
# define webdis_ntohll(x) \
    (((uint64_t)ntohl((uint32_t)((uint64_t)(x) & 0xFFFFFFFFu)) << 32) | \
     (uint64_t)ntohl((uint32_t)((uint64_t)(x) >> 32)))
#endif
38
39 static int
ws_compute_handshake(struct http_client * c,char * out,size_t * out_sz)40 ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) {
41
42 unsigned char *buffer, sha1_output[20];
43 char magic[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
44 SHA1Context ctx;
45 base64_encodestate b64_ctx;
46 int pos, i;
47
48 // websocket handshake
49 const char *key = client_get_header(c, "Sec-WebSocket-Key");
50 size_t key_sz = key?strlen(key):0, buffer_sz = key_sz + sizeof(magic) - 1;
51 if(!key || key_sz < 16 || key_sz > 32) { /* supposed to be exactly 16 bytes that were b64 encoded */
52 slog(c->s, WEBDIS_WARNING, "Invalid Sec-WebSocket-Key", 0);
53 return -1;
54 }
55 buffer = calloc(buffer_sz, 1);
56 if(!buffer) {
57 slog(c->s, WEBDIS_ERROR, "Failed to allocate memory for WS header", 0);
58 return -1;
59 }
60
61 // concatenate key and guid in buffer
62 memcpy(buffer, key, key_sz);
63 memcpy(buffer+key_sz, magic, sizeof(magic)-1);
64
65 // compute sha-1
66 SHA1Reset(&ctx);
67 SHA1Input(&ctx, buffer, buffer_sz);
68 SHA1Result(&ctx);
69 for(i = 0; i < (int)(20/sizeof(int)); ++i) { // put in correct byte order before memcpy.
70 ctx.Message_Digest[i] = ntohl(ctx.Message_Digest[i]);
71 }
72 memcpy(sha1_output, ctx.Message_Digest, 20);
73
74 // encode `sha1_output' in base 64, into `out'.
75 base64_init_encodestate(&b64_ctx);
76 pos = base64_encode_block((const char*)sha1_output, 20, out, &b64_ctx);
77 base64_encode_blockend(out + pos, &b64_ctx);
78
79 // compute length, without \n
80 *out_sz = strlen(out);
81 if(out[*out_sz-1] == '\n')
82 (*out_sz)--;
83
84 free(buffer);
85
86 return 0;
87 }
88
89 struct ws_client *
ws_client_new(struct http_client * http_client)90 ws_client_new(struct http_client *http_client) {
91
92 int db_num = http_client->w->s->cfg->database;
93 struct ws_client *ws = calloc(1, sizeof(struct ws_client));
94 struct evbuffer *rbuf = evbuffer_new();
95 struct evbuffer *wbuf = evbuffer_new();
96 redisAsyncContext *ac = pool_connect(http_client->w->pool, db_num, 0);
97
98 if(!ws || !rbuf || !wbuf) {
99 slog(http_client->s, WEBDIS_ERROR, "Failed to allocate memory for WS client", 0);
100 if(ws) free(ws);
101 if(rbuf) evbuffer_free(rbuf);
102 if(wbuf) evbuffer_free(wbuf);
103 if(ac) redisAsyncFree(ac);
104 return NULL;
105 }
106
107 http_client->ws = ws;
108 ws->http_client = http_client;
109 ws->rbuf = rbuf;
110 ws->wbuf = wbuf;
111 ws->ac = ac;
112
113 return ws;
114 }
115
116 void
ws_client_free(struct ws_client * ws)117 ws_client_free(struct ws_client *ws) {
118
119 /* mark WS client as closing to skip the Redis callback */
120 ws->close_after_events = 1;
121 pool_free_context(ws->ac); /* could trigger a cb via format_send_error */
122
123 struct http_client *c = ws->http_client;
124 if(c) {
125 close(c->fd);
126 c->ws = NULL; /* detach if needed */
127 }
128 evbuffer_free(ws->rbuf);
129 evbuffer_free(ws->wbuf);
130 if(ws->cmd) {
131 ws->cmd->ac = NULL; /* we've just free'd it */
132 cmd_free(ws->cmd);
133 }
134 free(ws);
135 if(c) http_client_free(c);
136 }
137
138
139 int
ws_handshake_reply(struct ws_client * ws)140 ws_handshake_reply(struct ws_client *ws) {
141
142 struct http_client *c = ws->http_client;
143 char sha1_handshake[40];
144 char *buffer = NULL, *p;
145 const char *origin = NULL, *host = NULL;
146 size_t origin_sz = 0, host_sz = 0, handshake_sz = 0, sz;
147
148 char template_start[] = "HTTP/1.1 101 Switching Protocols\r\n"
149 "Upgrade: websocket\r\n"
150 "Connection: Upgrade";
151 char template_accept[] = "\r\n" /* just after the start */
152 "Sec-WebSocket-Accept: "; /* %s */
153 char template_sec_origin[] = "\r\n"
154 "Sec-WebSocket-Origin: "; /* %s (optional header) */
155 char template_loc[] = "\r\n"
156 "Sec-WebSocket-Location: ws://"; /* %s%s */
157 char template_end[] = "\r\n\r\n";
158
159 if((origin = client_get_header(c, "Origin"))) {
160 origin_sz = strlen(origin);
161 } else if((origin = client_get_header(c, "Sec-WebSocket-Origin"))) {
162 origin_sz = strlen(origin);
163 }
164 if((host = client_get_header(c, "Host"))) {
165 host_sz = strlen(host);
166 }
167
168 /* need those headers */
169 if(!host || !host_sz || !c->path || !c->path_sz) {
170 slog(c->s, WEBDIS_WARNING, "Missing headers for WS handshake", 0);
171 return -1;
172 }
173
174 memset(sha1_handshake, 0, sizeof(sha1_handshake));
175 if(ws_compute_handshake(c, &sha1_handshake[0], &handshake_sz) != 0) {
176 /* failed to compute handshake. */
177 slog(c->s, WEBDIS_WARNING, "Failed to compute handshake", 0);
178 return -1;
179 }
180
181 sz = sizeof(template_start)-1
182 + sizeof(template_accept)-1 + handshake_sz
183 + (origin && origin_sz ? (sizeof(template_sec_origin)-1 + origin_sz) : 0) /* optional origin */
184 + sizeof(template_loc)-1 + host_sz + c->path_sz
185 + sizeof(template_end)-1;
186
187 p = buffer = malloc(sz);
188 if(!p) {
189 slog(c->s, WEBDIS_ERROR, "Failed to allocate buffer for WS handshake", 0);
190 return -1;
191 }
192
193 /* Concat all */
194
195 /* template_start */
196 memcpy(p, template_start, sizeof(template_start)-1);
197 p += sizeof(template_start)-1;
198
199 /* template_accept */
200 memcpy(p, template_accept, sizeof(template_accept)-1);
201 p += sizeof(template_accept)-1;
202 memcpy(p, &sha1_handshake[0], handshake_sz);
203 p += handshake_sz;
204
205 /* template_sec_origin */
206 if(origin && origin_sz) {
207 memcpy(p, template_sec_origin, sizeof(template_sec_origin)-1);
208 p += sizeof(template_sec_origin)-1;
209 memcpy(p, origin, origin_sz);
210 p += origin_sz;
211 }
212
213 /* template_loc */
214 memcpy(p, template_loc, sizeof(template_loc)-1);
215 p += sizeof(template_loc)-1;
216 memcpy(p, host, host_sz);
217 p += host_sz;
218 memcpy(p, c->path, c->path_sz);
219 p += c->path_sz;
220
221 /* template_end */
222 memcpy(p, template_end, sizeof(template_end)-1);
223 p += sizeof(template_end)-1;
224
225 int add_ret = evbuffer_add(ws->wbuf, buffer, sz);
226 free(buffer);
227 if(add_ret < 0) {
228 slog(c->s, WEBDIS_ERROR, "Failed to add response for WS handshake", 0);
229 return -1;
230 }
231
232 return ws_schedule_write(ws); /* will free buffer and response once sent */
233 }
234
235 static void
ws_log_cmd(struct ws_client * ws,struct cmd * cmd)236 ws_log_cmd(struct ws_client *ws, struct cmd *cmd) {
237 char log_msg[SLOG_MSG_MAX_LEN];
238 char *p = log_msg, *eom = log_msg + sizeof(log_msg) - 1;
239 if(!slog_enabled(ws->http_client->s, WEBDIS_DEBUG)) {
240 return;
241 }
242
243 memset(log_msg, 0, sizeof(log_msg));
244 memcpy(p, "WS: ", 4); /* WS prefix */
245 p += 4;
246
247 for(int i = 0; p < eom && i < cmd->count; i++) {
248 *p++ = '/';
249 char *arg = cmd->argv[i];
250 size_t arg_sz = cmd->argv_len[i];
251 size_t copy_sz = arg_sz < (size_t)(eom - p) ? arg_sz : (size_t)(eom - p);
252 memcpy(p, arg, copy_sz);
253 p += copy_sz;
254 }
255 slog(ws->http_client->s, WEBDIS_DEBUG, log_msg, p - log_msg);
256 }
257
258
259 static int
ws_execute(struct ws_client * ws,struct ws_msg * msg)260 ws_execute(struct ws_client *ws, struct ws_msg *msg) {
261
262 struct http_client *c = ws->http_client;
263 struct cmd*(*fun_extract)(struct http_client *, const char *, size_t) = NULL;
264 formatting_fun fun_reply = NULL;
265
266 if((c->path_sz == 1 && strncmp(c->path, "/", 1) == 0) ||
267 strncmp(c->path, "/.json", 6) == 0) {
268 fun_extract = json_ws_extract;
269 fun_reply = json_reply;
270 } else if(strncmp(c->path, "/.raw", 5) == 0) {
271 fun_extract = raw_ws_extract;
272 fun_reply = raw_reply;
273 }
274
275 if(fun_extract) {
276
277 /* Parse websocket frame into a cmd object. */
278 struct cmd *cmd = fun_extract(c, msg->payload, msg->payload_sz);
279
280 if(cmd) {
281 cmd->is_websocket = 1;
282
283 if(ws->cmd != NULL) {
284 /* This client already has its own connection to Redis
285 from a previous command; use it from now on. */
286
287 /* free args for the previous cmd */
288 cmd_free_argv(ws->cmd);
289 /* copy args from what we just parsed to the persistent command */
290 ws->cmd->count = cmd->count;
291 ws->cmd->argv = cmd->argv;
292 ws->cmd->argv_len = cmd->argv_len;
293 ws->cmd->pub_sub_client = c; /* mark as persistent, otherwise the Redis context will be freed */
294
295 cmd->argv = NULL;
296 cmd->argv_len = NULL;
297 cmd->count = 0;
298 cmd_free(cmd);
299
300 cmd = ws->cmd; /* replace pointer since we're about to pass it to cmd_send */
301 } else {
302 /* copy client info into cmd. */
303 cmd_setup(cmd, c);
304
305 /* First WS command; use Redis context from WS client. */
306 cmd->ac = ws->ac;
307 ws->cmd = cmd;
308 cmd->pub_sub_client = c;
309 }
310
311 int is_subscribe = cmd_is_subscribe_args(cmd);
312 int is_unsubscribe = cmd_is_unsubscribe_args(cmd);
313
314 if(ws->ran_subscribe && !is_subscribe && !is_unsubscribe) { /* disallow non-subscribe commands after a subscribe */
315 char error_msg[] = "Command not allowed after subscribe";
316 ws_frame_and_send_response(ws, WS_BINARY_FRAME, error_msg, sizeof(error_msg)-1);
317 } else { /* log and execute */
318 ws_log_cmd(ws, cmd);
319 cmd_send(cmd, fun_reply);
320 ws->ran_subscribe = is_subscribe;
321 }
322
323 return 0;
324 }
325 }
326
327 return -1;
328 }
329
330 static struct ws_msg *
ws_msg_new(enum ws_frame_type frame_type)331 ws_msg_new(enum ws_frame_type frame_type) {
332 struct ws_msg *msg = calloc(1, sizeof(struct ws_msg));
333 if(!msg) {
334 return NULL;
335 }
336 msg->type = frame_type;
337 return msg;
338 }
339
340 static int
ws_msg_add(struct ws_msg * m,const char * p,size_t psz,const unsigned char * mask)341 ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mask) {
342
343 /* add data to frame */
344 size_t i;
345 m->payload = realloc(m->payload, m->payload_sz + psz);
346 if(!m->payload) {
347 return -1;
348 }
349 memcpy(m->payload + m->payload_sz, p, psz);
350
351 /* apply mask */
352 for(i = 0; i < psz && mask; ++i) {
353 m->payload[m->payload_sz + i] = (unsigned char)p[i] ^ mask[i%4];
354 }
355
356 /* save new size */
357 m->payload_sz += psz;
358 return 0;
359 }
360
361 static void
ws_msg_free(struct ws_msg * m)362 ws_msg_free(struct ws_msg *m) {
363
364 free(m->payload);
365 free(m);
366 }
367
368 /* checks to see if we have a complete message */
369 static enum ws_state
ws_peek_data(struct ws_client * ws,struct ws_msg ** out_msg)370 ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
371
372 int has_mask;
373 uint64_t len;
374 const char *p;
375 char *frame;
376 unsigned char mask[4];
377 char fin_bit_set;
378 enum ws_frame_type frame_type;
379
380 /* parse frame and extract contents */
381 size_t sz = evbuffer_get_length(ws->rbuf);
382 if(sz < 8) {
383 return WS_READING; /* need more data */
384 }
385 /* copy into "frame" to process it */
386 frame = malloc(sz);
387 if(!frame) {
388 return WS_ERROR;
389 }
390 int rem_ret = evbuffer_remove(ws->rbuf, frame, sz);
391 if(rem_ret < 0) {
392 free(frame);
393 return WS_ERROR;
394 }
395
396 fin_bit_set = frame[0] & 0x80 ? 1 : 0;
397 frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */
398 has_mask = frame[1] & 0x80 ? 1:0;
399
400 /* get payload length */
401 len = frame[1] & 0x7f; /* remove leftmost bit */
402 if(len <= 125) { /* data starts right after the mask */
403 p = frame + 2 + (has_mask ? 4 : 0);
404 if(has_mask) memcpy(&mask, frame + 2, sizeof(mask));
405 } else if(len == 126) {
406 uint16_t sz16;
407 memcpy(&sz16, frame + 2, sizeof(uint16_t));
408 len = ntohs(sz16);
409 p = frame + 4 + (has_mask ? 4 : 0);
410 if(has_mask) memcpy(&mask, frame + 4, sizeof(mask));
411 } else if(len == 127) {
412 uint64_t sz64 = *((uint64_t*)(frame+2));
413 len = webdis_ntohll(sz64);
414 p = frame + 10 + (has_mask ? 4 : 0);
415 if(has_mask) memcpy(&mask, frame + 10, sizeof(mask));
416 } else {
417 free(frame);
418 return WS_ERROR;
419 }
420
421 /* we now have the (possibly masked) data starting in p, and its length. */
422 if(len > sz - (p - frame)) { /* not enough data */
423 int add_ret = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
424 free(frame);
425 return add_ret < 0 ? WS_ERROR : WS_READING;
426 }
427
428 int ev_copy = 0;
429 if(out_msg) { /* we're extracting the message */
430 struct ws_msg *msg = ws_msg_new(frame_type);
431 if(!msg) {
432 free(frame);
433 return WS_ERROR;
434 }
435 *out_msg = msg; /* attach for it to be freed by caller */
436
437 /* create new ws_msg object holding what we read */
438 int add_ret = ws_msg_add(msg, p, len, has_mask ? mask : NULL);
439 if(!add_ret) {
440 free(frame);
441 return WS_ERROR;
442 }
443
444 size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */
445 msg->total_sz += processed_sz;
446
447 ev_copy = evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */
448 } else { /* we're just peeking */
449 ev_copy = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
450 }
451 free(frame);
452
453 if(ev_copy < 0) {
454 return WS_ERROR;
455 } else if(fin_bit_set) {
456 return WS_MSG_COMPLETE;
457 } else {
458 return WS_READING; /* need more data */
459 }
460 }
461
462 /**
463 * Process some data just received on the socket.
464 * Returns the number of messages processed in out_processed, if non-NULL.
465 */
466 enum ws_state
ws_process_read_data(struct ws_client * ws,unsigned int * out_processed)467 ws_process_read_data(struct ws_client *ws, unsigned int *out_processed) {
468
469 enum ws_state state;
470 if(out_processed) *out_processed = 0;
471
472 state = ws_peek_data(ws, NULL); /* check for complete message */
473
474 while(state == WS_MSG_COMPLETE) {
475 int ret = 0;
476 struct ws_msg *msg = NULL;
477 ws_peek_data(ws, &msg); /* extract message */
478
479 if(msg && (msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME)) {
480 ret = ws_execute(ws, msg);
481 if(out_processed) (*out_processed)++;
482 } else if(msg && msg->type == WS_PING) { /* respond to ping */
483 ws_frame_and_send_response(ws, WS_PONG, msg->payload, msg->payload_sz);
484 } else if(msg && msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */
485 ws->close_after_events = 1;
486 ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, msg->payload, msg->payload_sz);
487 } else if(msg) {
488 char format[] = "Received unexpected WS frame type: 0x%x";
489 char error[(sizeof format)];
490 int error_len = snprintf(error, sizeof(error), format, msg->type);
491 slog(ws->http_client->s, WEBDIS_WARNING, error, error_len);
492 }
493
494 /* free frame */
495 if(msg) ws_msg_free(msg);
496
497 if(ret != 0) {
498 /* can't process frame. */
499 slog(ws->http_client->s, WEBDIS_DEBUG, "ws_process_read_data: ws_execute failed", 0);
500 return WS_ERROR;
501 }
502 state = ws_peek_data(ws, NULL);
503 }
504 return state;
505 }
506
507 int
ws_frame_and_send_response(struct ws_client * ws,enum ws_frame_type frame_type,const char * p,size_t sz)508 ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, const char *p, size_t sz) {
509
510 /* we can have as much as 14 bytes in the header:
511 * 1 byte for 4 flag bits + 4 frame type bits
512 * 1 byte for the payload length indicator
513 * 8 bytes for the size of the payload (at most)
514 * 4 bytes for the masking key (if present)
515 */
516 char *frame = malloc(sz + 14); /* create frame by prepending header */
517 size_t frame_sz = 0;
518 if(frame == NULL)
519 return -1;
520
521 /*
522 The length of the "Payload data", in bytes: if 0-125, that is the
523 payload length. If 126, the following 2 bytes interpreted as a
524 16-bit unsigned integer are the payload length. If 127, the
525 following 8 bytes interpreted as a 64-bit unsigned integer (the
526 most significant bit MUST be 0) are the payload length.
527 */
528 frame[0] = 0x80 | frame_type; /* frame type + EOM bit */
529 if(sz <= 125) {
530 frame[1] = sz;
531 memcpy(frame + 2, p, sz);
532 frame_sz = sz + 2;
533 } else if(sz <= 65536) {
534 uint16_t sz16 = htons(sz);
535 frame[1] = 126;
536 memcpy(frame + 2, &sz16, 2);
537 memcpy(frame + 4, p, sz);
538 frame_sz = sz + 4;
539 } else { /* sz > 65536 */
540 uint64_t sz_be = webdis_htonll(sz); /* big endian */
541 char sz64[8];
542 memcpy(sz64, &sz_be, 8);
543 frame[1] = 127;
544 memcpy(frame + 2, sz64, 8);
545 memcpy(frame + 10, p, sz);
546 frame_sz = sz + 10;
547 }
548
549 /* mark as keep alive, otherwise we'll close the connection after the first reply */
550 int add_ret = evbuffer_add(ws->wbuf, frame, frame_sz);
551 free(frame); /* no longer needed once added to buffer */
552 if(add_ret < 0) {
553 slog(ws->http_client->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0);
554 return -1;
555 }
556
557 /* send WS frame */
558 return ws_schedule_write(ws);
559 }
560
561 static void
ws_close_if_able(struct ws_client * ws)562 ws_close_if_able(struct ws_client *ws) {
563
564 ws->close_after_events = 1; /* note that we're closing */
565 if(ws->scheduled_read || ws->scheduled_write) {
566 return; /* still waiting for these events to trigger */
567 }
568 ws_client_free(ws); /* will close the socket */
569 }
570
571 static void
ws_can_read(int fd,short event,void * p)572 ws_can_read(int fd, short event, void *p) {
573
574 int ret;
575 struct ws_client *ws = p;
576 (void)event;
577
578 /* read pending data */
579 ws->scheduled_read = 0;
580 ret = evbuffer_read(ws->rbuf, fd, 4096);
581
582 if(ret <= 0) {
583 ws_client_free(ws); /* will close the socket */
584 } else if(ws->close_after_events) {
585 ws_close_if_able(ws);
586 } else {
587 enum ws_state state = ws_process_read_data(ws, NULL);
588 if(state == WS_READING) { /* need more data, schedule new read */
589 ws_monitor_input(ws);
590 } else if(state == WS_ERROR) {
591 ws_close_if_able(ws);
592 }
593 }
594 }
595
596
597 static void
ws_can_write(int fd,short event,void * p)598 ws_can_write(int fd, short event, void *p) {
599
600 int ret;
601 struct ws_client *ws = p;
602 (void)event;
603
604 ws->scheduled_write = 0;
605
606 /* send pending data */
607 ret = evbuffer_write_atmost(ws->wbuf, fd, 4096);
608
609 if(ret <= 0) {
610 ws_client_free(ws); /* will close the socket */
611 } else {
612 if(evbuffer_get_length(ws->wbuf) > 0) { /* more data to send */
613 ws_schedule_write(ws);
614 } else if(ws->close_after_events) { /* we're done! */
615 ws_close_if_able(ws);
616 } else {
617 /* check if we can read more data */
618 unsigned int processed = 0;
619 ws_process_read_data(ws, &processed); /* process any pending data we've already read */
620 ws_monitor_input(ws); /* let's read more from the client */
621 }
622 }
623 }
624
625 static int
ws_schedule_write(struct ws_client * ws)626 ws_schedule_write(struct ws_client *ws) {
627 struct http_client *c = ws->http_client;
628 if(!ws->scheduled_write) {
629 ws->scheduled_write = 1;
630 return event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL);
631 }
632 return 0;
633 }
634
635 int
ws_monitor_input(struct ws_client * ws)636 ws_monitor_input(struct ws_client *ws) {
637 struct http_client *c = ws->http_client;
638 if(!ws->scheduled_read) {
639 ws->scheduled_read = 1;
640 return event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL);
641 }
642 return 0;
643 }
644