1 /*
2 Copyright (c) 2007-2019 Contributors as noted in the AUTHORS file
3
4 This file is part of libzmq, the ZeroMQ core engine in C++.
5
6 libzmq is free software; you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 3 of the License, or
9 (at your option) any later version.
10
11 As a special exception, the Contributors give you permission to link
12 this library with independent modules to produce an executable,
13 regardless of the license terms of these independent modules, and to
14 copy and distribute the resulting executable under terms of your choice,
15 provided that you also meet, for each linked independent module, the
16 terms and conditions of the license of that module. An independent
17 module is a module which is not derived from or based on this library.
18 If you modify this library, you must extend this exception to your
19 version of the library.
20
21 libzmq is distributed in the hope that it will be useful, but WITHOUT
22 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
23 FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
24 License for more details.
25
26 You should have received a copy of the GNU Lesser General Public License
27 along with this program. If not, see <http://www.gnu.org/licenses/>.
28 */
29
30 #include "precompiled.hpp"
31
32 #ifdef ZMQ_USE_NSS
33 #include <secoid.h>
34 #include <sechash.h>
35 #define SHA_DIGEST_LENGTH 20
36 #elif defined ZMQ_USE_BUILTIN_SHA1
37 #include "../external/sha1/sha1.h"
38 #elif defined ZMQ_USE_GNUTLS
39 #define SHA_DIGEST_LENGTH 20
40 #include <gnutls/gnutls.h>
41 #include <gnutls/crypto.h>
42 #endif
43
44 #if !defined ZMQ_HAVE_WINDOWS
45 #include <sys/types.h>
46 #include <unistd.h>
47 #include <sys/socket.h>
48 #include <netinet/in.h>
49 #include <arpa/inet.h>
50 #ifdef ZMQ_HAVE_VXWORKS
51 #include <sockLib.h>
52 #endif
53 #endif
54
55 #include <cstring>
56
57 #include "compat.hpp"
58 #include "tcp.hpp"
59 #include "ws_engine.hpp"
60 #include "session_base.hpp"
61 #include "err.hpp"
62 #include "ip.hpp"
63 #include "random.hpp"
64 #include "ws_decoder.hpp"
65 #include "ws_encoder.hpp"
66 #include "null_mechanism.hpp"
67 #include "plain_server.hpp"
68 #include "plain_client.hpp"
69
70 #ifdef ZMQ_HAVE_CURVE
71 #include "curve_client.hpp"
72 #include "curve_server.hpp"
73 #endif
74
75 // OSX uses a different name for this socket option
76 #ifndef IPV6_ADD_MEMBERSHIP
77 #define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP
78 #endif
79
80 #ifdef __APPLE__
81 #include <TargetConditionals.h>
82 #endif
83
//  Base64-encodes in_len_ bytes from in_ into out_ and NUL-terminates it.
//  Returns the encoded length (without the NUL), or -1 when out_len_ is too
//  small for the encoded text plus the terminator.
static int
encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_);

//  Writes SHA-1 (key_ followed by the fixed RFC 6455 GUID) into hash_; this is
//  the raw digest that, base64-encoded, forms the Sec-WebSocket-Accept value.
static void compute_accept_key (char *key_,
                                unsigned char hash_[SHA_DIGEST_LENGTH]);
89
ws_engine_t(fd_t fd_,const options_t & options_,const endpoint_uri_pair_t & endpoint_uri_pair_,const ws_address_t & address_,bool client_)90 zmq::ws_engine_t::ws_engine_t (fd_t fd_,
91 const options_t &options_,
92 const endpoint_uri_pair_t &endpoint_uri_pair_,
93 const ws_address_t &address_,
94 bool client_) :
95 stream_engine_base_t (fd_, options_, endpoint_uri_pair_, true),
96 _client (client_),
97 _address (address_),
98 _client_handshake_state (client_handshake_initial),
99 _server_handshake_state (handshake_initial),
100 _header_name_position (0),
101 _header_value_position (0),
102 _header_upgrade_websocket (false),
103 _header_connection_upgrade (false),
104 _heartbeat_timeout (0)
105 {
106 memset (_websocket_key, 0, MAX_HEADER_VALUE_LENGTH + 1);
107 memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1);
108 memset (_websocket_protocol, 0, 256);
109
110 _next_msg = &ws_engine_t::next_handshake_command;
111 _process_msg = &ws_engine_t::process_handshake_command;
112 _close_msg.init ();
113
114 if (_options.heartbeat_interval > 0) {
115 _heartbeat_timeout = _options.heartbeat_timeout;
116 if (_heartbeat_timeout == -1)
117 _heartbeat_timeout = _options.heartbeat_interval;
118 }
119 }
120
~ws_engine_t()121 zmq::ws_engine_t::~ws_engine_t ()
122 {
123 _close_msg.close ();
124 }
125
start_ws_handshake()126 void zmq::ws_engine_t::start_ws_handshake ()
127 {
128 if (_client) {
129 const char *protocol;
130 if (_options.mechanism == ZMQ_NULL)
131 protocol = "ZWS2.0/NULL,ZWS2.0";
132 else if (_options.mechanism == ZMQ_PLAIN)
133 protocol = "ZWS2.0/PLAIN";
134 #ifdef ZMQ_HAVE_CURVE
135 else if (_options.mechanism == ZMQ_CURVE)
136 protocol = "ZWS2.0/CURVE";
137 #endif
138 else {
139 // Avoid unitialized variable error breaking UWP build
140 protocol = "";
141 assert (false);
142 }
143
144 unsigned char nonce[16];
145 int *p = reinterpret_cast<int *> (nonce);
146
147 // The nonce doesn't have to be secure one, it is just use to avoid proxy cache
148 *p = zmq::generate_random ();
149 *(p + 1) = zmq::generate_random ();
150 *(p + 2) = zmq::generate_random ();
151 *(p + 3) = zmq::generate_random ();
152
153 int size =
154 encode_base64 (nonce, 16, _websocket_key, MAX_HEADER_VALUE_LENGTH);
155 assert (size > 0);
156
157 size = snprintf (
158 reinterpret_cast<char *> (_write_buffer), WS_BUFFER_SIZE,
159 "GET %s HTTP/1.1\r\n"
160 "Host: %s\r\n"
161 "Upgrade: websocket\r\n"
162 "Connection: Upgrade\r\n"
163 "Sec-WebSocket-Key: %s\r\n"
164 "Sec-WebSocket-Protocol: %s\r\n"
165 "Sec-WebSocket-Version: 13\r\n\r\n",
166 _address.path (), _address.host (), _websocket_key, protocol);
167 assert (size > 0 && size < WS_BUFFER_SIZE);
168 _outpos = _write_buffer;
169 _outsize = size;
170 set_pollout ();
171 }
172 }
173
plug_internal()174 void zmq::ws_engine_t::plug_internal ()
175 {
176 start_ws_handshake ();
177 set_pollin ();
178 in_event ();
179 }
180
routing_id_msg(msg_t * msg_)181 int zmq::ws_engine_t::routing_id_msg (msg_t *msg_)
182 {
183 const int rc = msg_->init_size (_options.routing_id_size);
184 errno_assert (rc == 0);
185 if (_options.routing_id_size > 0)
186 memcpy (msg_->data (), _options.routing_id, _options.routing_id_size);
187 _next_msg = &ws_engine_t::pull_msg_from_session;
188
189 return 0;
190 }
191
process_routing_id_msg(msg_t * msg_)192 int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_)
193 {
194 if (_options.recv_routing_id) {
195 msg_->set_flags (msg_t::routing_id);
196 const int rc = session ()->push_msg (msg_);
197 errno_assert (rc == 0);
198 } else {
199 int rc = msg_->close ();
200 errno_assert (rc == 0);
201 rc = msg_->init ();
202 errno_assert (rc == 0);
203 }
204
205 _process_msg = &ws_engine_t::push_msg_to_session;
206
207 return 0;
208 }
209
select_protocol(const char * protocol_)210 bool zmq::ws_engine_t::select_protocol (const char *protocol_)
211 {
212 if (_options.mechanism == ZMQ_NULL && (strcmp ("ZWS2.0", protocol_) == 0)) {
213 _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
214 &ws_engine_t::routing_id_msg);
215 _process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
216 &ws_engine_t::process_routing_id_msg);
217
218 // No mechanism in place, enabling heartbeat
219 if (_options.heartbeat_interval > 0 && !_has_heartbeat_timer) {
220 add_timer (_options.heartbeat_interval, heartbeat_ivl_timer_id);
221 _has_heartbeat_timer = true;
222 }
223
224 return true;
225 }
226 if (_options.mechanism == ZMQ_NULL
227 && strcmp ("ZWS2.0/NULL", protocol_) == 0) {
228 _mechanism = new (std::nothrow)
229 null_mechanism_t (session (), _peer_address, _options);
230 alloc_assert (_mechanism);
231 return true;
232 } else if (_options.mechanism == ZMQ_PLAIN
233 && strcmp ("ZWS2.0/PLAIN", protocol_) == 0) {
234 if (_options.as_server)
235 _mechanism = new (std::nothrow)
236 plain_server_t (session (), _peer_address, _options);
237 else
238 _mechanism =
239 new (std::nothrow) plain_client_t (session (), _options);
240 alloc_assert (_mechanism);
241 return true;
242 }
243 #ifdef ZMQ_HAVE_CURVE
244 else if (_options.mechanism == ZMQ_CURVE
245 && strcmp ("ZWS2.0/CURVE", protocol_) == 0) {
246 if (_options.as_server)
247 _mechanism = new (std::nothrow)
248 curve_server_t (session (), _peer_address, _options, false);
249 else
250 _mechanism =
251 new (std::nothrow) curve_client_t (session (), _options, false);
252 alloc_assert (_mechanism);
253 return true;
254 }
255 #endif
256
257 return false;
258 }
259
handshake()260 bool zmq::ws_engine_t::handshake ()
261 {
262 bool complete;
263
264 if (_client)
265 complete = client_handshake ();
266 else
267 complete = server_handshake ();
268
269 if (complete) {
270 _encoder =
271 new (std::nothrow) ws_encoder_t (_options.out_batch_size, _client);
272 alloc_assert (_encoder);
273
274 _decoder = new (std::nothrow)
275 ws_decoder_t (_options.in_batch_size, _options.maxmsgsize,
276 _options.zero_copy, !_client);
277 alloc_assert (_decoder);
278
279 socket ()->event_handshake_succeeded (_endpoint_uri_pair, 0);
280
281 set_pollout ();
282 }
283
284 return complete;
285 }
286
server_handshake()287 bool zmq::ws_engine_t::server_handshake ()
288 {
289 const int nbytes = read (_read_buffer, WS_BUFFER_SIZE);
290 if (nbytes == -1) {
291 if (errno != EAGAIN)
292 error (zmq::i_engine::connection_error);
293 return false;
294 }
295
296 _inpos = _read_buffer;
297 _insize = nbytes;
298
299 while (_insize > 0) {
300 const char c = static_cast<char> (*_inpos);
301
302 switch (_server_handshake_state) {
303 case handshake_initial:
304 if (c == 'G')
305 _server_handshake_state = request_line_G;
306 else
307 _server_handshake_state = handshake_error;
308 break;
309 case request_line_G:
310 if (c == 'E')
311 _server_handshake_state = request_line_GE;
312 else
313 _server_handshake_state = handshake_error;
314 break;
315 case request_line_GE:
316 if (c == 'T')
317 _server_handshake_state = request_line_GET;
318 else
319 _server_handshake_state = handshake_error;
320 break;
321 case request_line_GET:
322 if (c == ' ')
323 _server_handshake_state = request_line_GET_space;
324 else
325 _server_handshake_state = handshake_error;
326 break;
327 case request_line_GET_space:
328 if (c == '\r' || c == '\n')
329 _server_handshake_state = handshake_error;
330 // TODO: instead of check what is not allowed check what is allowed
331 if (c != ' ')
332 _server_handshake_state = request_line_resource;
333 else
334 _server_handshake_state = request_line_GET_space;
335 break;
336 case request_line_resource:
337 if (c == '\r' || c == '\n')
338 _server_handshake_state = handshake_error;
339 else if (c == ' ')
340 _server_handshake_state = request_line_resource_space;
341 else
342 _server_handshake_state = request_line_resource;
343 break;
344 case request_line_resource_space:
345 if (c == 'H')
346 _server_handshake_state = request_line_H;
347 else
348 _server_handshake_state = handshake_error;
349 break;
350 case request_line_H:
351 if (c == 'T')
352 _server_handshake_state = request_line_HT;
353 else
354 _server_handshake_state = handshake_error;
355 break;
356 case request_line_HT:
357 if (c == 'T')
358 _server_handshake_state = request_line_HTT;
359 else
360 _server_handshake_state = handshake_error;
361 break;
362 case request_line_HTT:
363 if (c == 'P')
364 _server_handshake_state = request_line_HTTP;
365 else
366 _server_handshake_state = handshake_error;
367 break;
368 case request_line_HTTP:
369 if (c == '/')
370 _server_handshake_state = request_line_HTTP_slash;
371 else
372 _server_handshake_state = handshake_error;
373 break;
374 case request_line_HTTP_slash:
375 if (c == '1')
376 _server_handshake_state = request_line_HTTP_slash_1;
377 else
378 _server_handshake_state = handshake_error;
379 break;
380 case request_line_HTTP_slash_1:
381 if (c == '.')
382 _server_handshake_state = request_line_HTTP_slash_1_dot;
383 else
384 _server_handshake_state = handshake_error;
385 break;
386 case request_line_HTTP_slash_1_dot:
387 if (c == '1')
388 _server_handshake_state = request_line_HTTP_slash_1_dot_1;
389 else
390 _server_handshake_state = handshake_error;
391 break;
392 case request_line_HTTP_slash_1_dot_1:
393 if (c == '\r')
394 _server_handshake_state = request_line_cr;
395 else
396 _server_handshake_state = handshake_error;
397 break;
398 case request_line_cr:
399 if (c == '\n')
400 _server_handshake_state = header_field_begin_name;
401 else
402 _server_handshake_state = handshake_error;
403 break;
404 case header_field_begin_name:
405 switch (c) {
406 case '\r':
407 _server_handshake_state = handshake_end_line_cr;
408 break;
409 case '\n':
410 _server_handshake_state = handshake_error;
411 break;
412 default:
413 _header_name[0] = c;
414 _header_name_position = 1;
415 _server_handshake_state = header_field_name;
416 break;
417 }
418 break;
419 case header_field_name:
420 if (c == '\r' || c == '\n')
421 _server_handshake_state = handshake_error;
422 else if (c == ':') {
423 _header_name[_header_name_position] = '\0';
424 _server_handshake_state = header_field_colon;
425 } else if (_header_name_position + 1 > MAX_HEADER_NAME_LENGTH)
426 _server_handshake_state = handshake_error;
427 else {
428 _header_name[_header_name_position] = c;
429 _header_name_position++;
430 _server_handshake_state = header_field_name;
431 }
432 break;
433 case header_field_colon:
434 case header_field_value_trailing_space:
435 if (c == '\n')
436 _server_handshake_state = handshake_error;
437 else if (c == '\r')
438 _server_handshake_state = header_field_cr;
439 else if (c == ' ')
440 _server_handshake_state = header_field_value_trailing_space;
441 else {
442 _header_value[0] = c;
443 _header_value_position = 1;
444 _server_handshake_state = header_field_value;
445 }
446 break;
447 case header_field_value:
448 if (c == '\n')
449 _server_handshake_state = handshake_error;
450 else if (c == '\r') {
451 _header_value[_header_value_position] = '\0';
452
453 if (strcasecmp ("upgrade", _header_name) == 0)
454 _header_upgrade_websocket =
455 strcasecmp ("websocket", _header_value) == 0;
456 else if (strcasecmp ("connection", _header_name) == 0)
457 _header_connection_upgrade =
458 strcasecmp ("upgrade", _header_value) == 0;
459 else if (strcasecmp ("Sec-WebSocket-Key", _header_name)
460 == 0)
461 strcpy_s (_websocket_key, _header_value);
462 else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
463 == 0) {
464 // Currently only the ZWS2.0 is supported
465 // Sec-WebSocket-Protocol can appear multiple times or be a comma separated list
466 // if _websocket_protocol is already set we skip the check
467 if (_websocket_protocol[0] == '\0') {
468 char *rest = 0;
469 char *p = strtok_r (_header_value, ",", &rest);
470 while (p != NULL) {
471 if (*p == ' ')
472 p++;
473
474 if (select_protocol (p)) {
475 strcpy_s (_websocket_protocol, p);
476 break;
477 }
478
479 p = strtok_r (NULL, ",", &rest);
480 }
481 }
482 }
483
484 _server_handshake_state = header_field_cr;
485 } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)
486 _server_handshake_state = handshake_error;
487 else {
488 _header_value[_header_value_position] = c;
489 _header_value_position++;
490 _server_handshake_state = header_field_value;
491 }
492 break;
493 case header_field_cr:
494 if (c == '\n')
495 _server_handshake_state = header_field_begin_name;
496 else
497 _server_handshake_state = handshake_error;
498 break;
499 case handshake_end_line_cr:
500 if (c == '\n') {
501 if (_header_connection_upgrade && _header_upgrade_websocket
502 && _websocket_protocol[0] != '\0'
503 && _websocket_key[0] != '\0') {
504 _server_handshake_state = handshake_complete;
505
506 unsigned char hash[SHA_DIGEST_LENGTH];
507 compute_accept_key (_websocket_key, hash);
508
509 const int accept_key_len = encode_base64 (
510 hash, SHA_DIGEST_LENGTH, _websocket_accept,
511 MAX_HEADER_VALUE_LENGTH);
512 assert (accept_key_len > 0);
513 _websocket_accept[accept_key_len] = '\0';
514
515 const int written =
516 snprintf (reinterpret_cast<char *> (_write_buffer),
517 WS_BUFFER_SIZE,
518 "HTTP/1.1 101 Switching Protocols\r\n"
519 "Upgrade: websocket\r\n"
520 "Connection: Upgrade\r\n"
521 "Sec-WebSocket-Accept: %s\r\n"
522 "Sec-WebSocket-Protocol: %s\r\n"
523 "\r\n",
524 _websocket_accept, _websocket_protocol);
525 assert (written >= 0 && written < WS_BUFFER_SIZE);
526 _outpos = _write_buffer;
527 _outsize = written;
528
529 _inpos++;
530 _insize--;
531
532 return true;
533 }
534 _server_handshake_state = handshake_error;
535 } else
536 _server_handshake_state = handshake_error;
537 break;
538 default:
539 assert (false);
540 }
541
542 _inpos++;
543 _insize--;
544
545 if (_server_handshake_state == handshake_error) {
546 // TODO: send bad request
547
548 socket ()->event_handshake_failed_protocol (
549 _endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
550
551 error (zmq::i_engine::protocol_error);
552 return false;
553 }
554 }
555 return false;
556 }
557
client_handshake()558 bool zmq::ws_engine_t::client_handshake ()
559 {
560 const int nbytes = read (_read_buffer, WS_BUFFER_SIZE);
561 if (nbytes == -1) {
562 if (errno != EAGAIN)
563 error (zmq::i_engine::connection_error);
564 return false;
565 }
566
567 _inpos = _read_buffer;
568 _insize = nbytes;
569
570 while (_insize > 0) {
571 const char c = static_cast<char> (*_inpos);
572
573 switch (_client_handshake_state) {
574 case client_handshake_initial:
575 if (c == 'H')
576 _client_handshake_state = response_line_H;
577 else
578 _client_handshake_state = client_handshake_error;
579 break;
580 case response_line_H:
581 if (c == 'T')
582 _client_handshake_state = response_line_HT;
583 else
584 _client_handshake_state = client_handshake_error;
585 break;
586 case response_line_HT:
587 if (c == 'T')
588 _client_handshake_state = response_line_HTT;
589 else
590 _client_handshake_state = client_handshake_error;
591 break;
592 case response_line_HTT:
593 if (c == 'P')
594 _client_handshake_state = response_line_HTTP;
595 else
596 _client_handshake_state = client_handshake_error;
597 break;
598 case response_line_HTTP:
599 if (c == '/')
600 _client_handshake_state = response_line_HTTP_slash;
601 else
602 _client_handshake_state = client_handshake_error;
603 break;
604 case response_line_HTTP_slash:
605 if (c == '1')
606 _client_handshake_state = response_line_HTTP_slash_1;
607 else
608 _client_handshake_state = client_handshake_error;
609 break;
610 case response_line_HTTP_slash_1:
611 if (c == '.')
612 _client_handshake_state = response_line_HTTP_slash_1_dot;
613 else
614 _client_handshake_state = client_handshake_error;
615 break;
616 case response_line_HTTP_slash_1_dot:
617 if (c == '1')
618 _client_handshake_state = response_line_HTTP_slash_1_dot_1;
619 else
620 _client_handshake_state = client_handshake_error;
621 break;
622 case response_line_HTTP_slash_1_dot_1:
623 if (c == ' ')
624 _client_handshake_state =
625 response_line_HTTP_slash_1_dot_1_space;
626 else
627 _client_handshake_state = client_handshake_error;
628 break;
629 case response_line_HTTP_slash_1_dot_1_space:
630 if (c == ' ')
631 _client_handshake_state =
632 response_line_HTTP_slash_1_dot_1_space;
633 else if (c == '1')
634 _client_handshake_state = response_line_status_1;
635 else
636 _client_handshake_state = client_handshake_error;
637 break;
638 case response_line_status_1:
639 if (c == '0')
640 _client_handshake_state = response_line_status_10;
641 else
642 _client_handshake_state = client_handshake_error;
643 break;
644 case response_line_status_10:
645 if (c == '1')
646 _client_handshake_state = response_line_status_101;
647 else
648 _client_handshake_state = client_handshake_error;
649 break;
650 case response_line_status_101:
651 if (c == ' ')
652 _client_handshake_state = response_line_status_101_space;
653 else
654 _client_handshake_state = client_handshake_error;
655 break;
656 case response_line_status_101_space:
657 if (c == ' ')
658 _client_handshake_state = response_line_status_101_space;
659 else if (c == 'S')
660 _client_handshake_state = response_line_s;
661 else
662 _client_handshake_state = client_handshake_error;
663 break;
664 case response_line_s:
665 if (c == 'w')
666 _client_handshake_state = response_line_sw;
667 else
668 _client_handshake_state = client_handshake_error;
669 break;
670 case response_line_sw:
671 if (c == 'i')
672 _client_handshake_state = response_line_swi;
673 else
674 _client_handshake_state = client_handshake_error;
675 break;
676 case response_line_swi:
677 if (c == 't')
678 _client_handshake_state = response_line_swit;
679 else
680 _client_handshake_state = client_handshake_error;
681 break;
682 case response_line_swit:
683 if (c == 'c')
684 _client_handshake_state = response_line_switc;
685 else
686 _client_handshake_state = client_handshake_error;
687 break;
688 case response_line_switc:
689 if (c == 'h')
690 _client_handshake_state = response_line_switch;
691 else
692 _client_handshake_state = client_handshake_error;
693 break;
694 case response_line_switch:
695 if (c == 'i')
696 _client_handshake_state = response_line_switchi;
697 else
698 _client_handshake_state = client_handshake_error;
699 break;
700 case response_line_switchi:
701 if (c == 'n')
702 _client_handshake_state = response_line_switchin;
703 else
704 _client_handshake_state = client_handshake_error;
705 break;
706 case response_line_switchin:
707 if (c == 'g')
708 _client_handshake_state = response_line_switching;
709 else
710 _client_handshake_state = client_handshake_error;
711 break;
712 case response_line_switching:
713 if (c == ' ')
714 _client_handshake_state = response_line_switching_space;
715 else
716 _client_handshake_state = client_handshake_error;
717 break;
718 case response_line_switching_space:
719 if (c == 'P')
720 _client_handshake_state = response_line_p;
721 else
722 _client_handshake_state = client_handshake_error;
723 break;
724 case response_line_p:
725 if (c == 'r')
726 _client_handshake_state = response_line_pr;
727 else
728 _client_handshake_state = client_handshake_error;
729 break;
730 case response_line_pr:
731 if (c == 'o')
732 _client_handshake_state = response_line_pro;
733 else
734 _client_handshake_state = client_handshake_error;
735 break;
736 case response_line_pro:
737 if (c == 't')
738 _client_handshake_state = response_line_prot;
739 else
740 _client_handshake_state = client_handshake_error;
741 break;
742 case response_line_prot:
743 if (c == 'o')
744 _client_handshake_state = response_line_proto;
745 else
746 _client_handshake_state = client_handshake_error;
747 break;
748 case response_line_proto:
749 if (c == 'c')
750 _client_handshake_state = response_line_protoc;
751 else
752 _client_handshake_state = client_handshake_error;
753 break;
754 case response_line_protoc:
755 if (c == 'o')
756 _client_handshake_state = response_line_protoco;
757 else
758 _client_handshake_state = client_handshake_error;
759 break;
760 case response_line_protoco:
761 if (c == 'l')
762 _client_handshake_state = response_line_protocol;
763 else
764 _client_handshake_state = client_handshake_error;
765 break;
766 case response_line_protocol:
767 if (c == 's')
768 _client_handshake_state = response_line_protocols;
769 else
770 _client_handshake_state = client_handshake_error;
771 break;
772 case response_line_protocols:
773 if (c == '\r')
774 _client_handshake_state = response_line_cr;
775 else
776 _client_handshake_state = client_handshake_error;
777 break;
778 case response_line_cr:
779 if (c == '\n')
780 _client_handshake_state = client_header_field_begin_name;
781 else
782 _client_handshake_state = client_handshake_error;
783 break;
784 case client_header_field_begin_name:
785 switch (c) {
786 case '\r':
787 _client_handshake_state = client_handshake_end_line_cr;
788 break;
789 case '\n':
790 _client_handshake_state = client_handshake_error;
791 break;
792 default:
793 _header_name[0] = c;
794 _header_name_position = 1;
795 _client_handshake_state = client_header_field_name;
796 break;
797 }
798 break;
799 case client_header_field_name:
800 if (c == '\r' || c == '\n')
801 _client_handshake_state = client_handshake_error;
802 else if (c == ':') {
803 _header_name[_header_name_position] = '\0';
804 _client_handshake_state = client_header_field_colon;
805 } else if (_header_name_position + 1 > MAX_HEADER_NAME_LENGTH)
806 _client_handshake_state = client_handshake_error;
807 else {
808 _header_name[_header_name_position] = c;
809 _header_name_position++;
810 _client_handshake_state = client_header_field_name;
811 }
812 break;
813 case client_header_field_colon:
814 case client_header_field_value_trailing_space:
815 if (c == '\n')
816 _client_handshake_state = client_handshake_error;
817 else if (c == '\r')
818 _client_handshake_state = client_header_field_cr;
819 else if (c == ' ')
820 _client_handshake_state =
821 client_header_field_value_trailing_space;
822 else {
823 _header_value[0] = c;
824 _header_value_position = 1;
825 _client_handshake_state = client_header_field_value;
826 }
827 break;
828 case client_header_field_value:
829 if (c == '\n')
830 _client_handshake_state = client_handshake_error;
831 else if (c == '\r') {
832 _header_value[_header_value_position] = '\0';
833
834 if (strcasecmp ("upgrade", _header_name) == 0)
835 _header_upgrade_websocket =
836 strcasecmp ("websocket", _header_value) == 0;
837 else if (strcasecmp ("connection", _header_name) == 0)
838 _header_connection_upgrade =
839 strcasecmp ("upgrade", _header_value) == 0;
840 else if (strcasecmp ("Sec-WebSocket-Accept", _header_name)
841 == 0)
842 strcpy_s (_websocket_accept, _header_value);
843 else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
844 == 0) {
845 if (_mechanism) {
846 _client_handshake_state = client_handshake_error;
847 break;
848 }
849 if (select_protocol (_header_value))
850 strcpy_s (_websocket_protocol, _header_value);
851 }
852 _client_handshake_state = client_header_field_cr;
853 } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)
854 _client_handshake_state = client_handshake_error;
855 else {
856 _header_value[_header_value_position] = c;
857 _header_value_position++;
858 _client_handshake_state = client_header_field_value;
859 }
860 break;
861 case client_header_field_cr:
862 if (c == '\n')
863 _client_handshake_state = client_header_field_begin_name;
864 else
865 _client_handshake_state = client_handshake_error;
866 break;
867 case client_handshake_end_line_cr:
868 if (c == '\n') {
869 if (_header_connection_upgrade && _header_upgrade_websocket
870 && _websocket_protocol[0] != '\0'
871 && _websocket_accept[0] != '\0') {
872 _client_handshake_state = client_handshake_complete;
873
874 // TODO: validate accept key
875
876 _inpos++;
877 _insize--;
878
879 return true;
880 }
881 _client_handshake_state = client_handshake_error;
882 } else
883 _client_handshake_state = client_handshake_error;
884 break;
885 default:
886 assert (false);
887 }
888
889 _inpos++;
890 _insize--;
891
892 if (_client_handshake_state == client_handshake_error) {
893 socket ()->event_handshake_failed_protocol (
894 _endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
895
896 error (zmq::i_engine::protocol_error);
897 return false;
898 }
899 }
900
901 return false;
902 }
903
decode_and_push(msg_t * msg_)904 int zmq::ws_engine_t::decode_and_push (msg_t *msg_)
905 {
906 zmq_assert (_mechanism != NULL);
907
908 // with WS engine, ping and pong commands are control messages and should not go through any mechanism
909 if (msg_->is_ping () || msg_->is_pong () || msg_->is_close_cmd ()) {
910 if (process_command_message (msg_) == -1)
911 return -1;
912 } else if (_mechanism->decode (msg_) == -1)
913 return -1;
914
915 if (_has_timeout_timer) {
916 _has_timeout_timer = false;
917 cancel_timer (heartbeat_timeout_timer_id);
918 }
919
920 if (msg_->flags () & msg_t::command && !msg_->is_ping ()
921 && !msg_->is_pong () && !msg_->is_close_cmd ())
922 process_command_message (msg_);
923
924 if (_metadata)
925 msg_->set_metadata (_metadata);
926 if (session ()->push_msg (msg_) == -1) {
927 if (errno == EAGAIN)
928 _process_msg = &ws_engine_t::push_one_then_decode_and_push;
929 return -1;
930 }
931 return 0;
932 }
933
produce_close_message(msg_t * msg_)934 int zmq::ws_engine_t::produce_close_message (msg_t *msg_)
935 {
936 int rc = msg_->move (_close_msg);
937 errno_assert (rc == 0);
938
939 _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
940 &ws_engine_t::produce_no_msg_after_close);
941
942 return rc;
943 }
944
produce_no_msg_after_close(msg_t * msg_)945 int zmq::ws_engine_t::produce_no_msg_after_close (msg_t *msg_)
946 {
947 LIBZMQ_UNUSED (msg_);
948 _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
949 &ws_engine_t::close_connection_after_close);
950
951 errno = EAGAIN;
952 return -1;
953 }
954
close_connection_after_close(msg_t * msg_)955 int zmq::ws_engine_t::close_connection_after_close (msg_t *msg_)
956 {
957 LIBZMQ_UNUSED (msg_);
958 error (connection_error);
959 errno = ECONNRESET;
960 return -1;
961 }
962
produce_ping_message(msg_t * msg_)963 int zmq::ws_engine_t::produce_ping_message (msg_t *msg_)
964 {
965 int rc = msg_->init ();
966 errno_assert (rc == 0);
967 msg_->set_flags (msg_t::command | msg_t::ping);
968
969 _next_msg = &ws_engine_t::pull_and_encode;
970 if (!_has_timeout_timer && _heartbeat_timeout > 0) {
971 add_timer (_heartbeat_timeout, heartbeat_timeout_timer_id);
972 _has_timeout_timer = true;
973 }
974
975 return rc;
976 }
977
978
produce_pong_message(msg_t * msg_)979 int zmq::ws_engine_t::produce_pong_message (msg_t *msg_)
980 {
981 int rc = msg_->init ();
982 errno_assert (rc == 0);
983 msg_->set_flags (msg_t::command | msg_t::pong);
984
985 _next_msg = &ws_engine_t::pull_and_encode;
986 return rc;
987 }
988
989
process_command_message(msg_t * msg_)990 int zmq::ws_engine_t::process_command_message (msg_t *msg_)
991 {
992 if (msg_->is_ping ()) {
993 _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
994 &ws_engine_t::produce_pong_message);
995 out_event ();
996 } else if (msg_->is_close_cmd ()) {
997 int rc = _close_msg.copy (*msg_);
998 errno_assert (rc == 0);
999 _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
1000 &ws_engine_t::produce_close_message);
1001 out_event ();
1002 }
1003
1004 return 0;
1005 }
1006
//  Base64-encodes (standard alphabet, '=' padding) in_len_ bytes from in_.
//  The encoded text and a terminating NUL are written to out_, which has room
//  for out_len_ chars. Returns the encoded length without the NUL, or -1 when
//  out_ is too small; on failure out_ may contain a partial, unterminated
//  result.
static int
encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_)
{
    static const char alphabet[] =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    uint32_t bits = 0; //  not-yet-emitted input bits live in the low `pending` bits
    int pending = 0;
    int written = 0;

    for (int idx = 0; idx < in_len_; ++idx) {
        bits = (bits << 8) | in_[idx];
        //  Emit every complete 6-bit group, most significant first.
        for (pending += 8; pending >= 6;) {
            pending -= 6;
            if (written >= out_len_)
                return -1; //  output buffer too small
            out_[written++] = alphabet[(bits >> pending) & 0x3F];
        }
    }

    //  2 or 4 left-over bits are zero-extended into one final symbol.
    if (pending != 0) {
        if (written >= out_len_)
            return -1;
        out_[written++] = alphabet[(bits << (6 - pending)) & 0x3F];
    }

    //  Pad the output to a multiple of four symbols.
    for (; written % 4 != 0; ++written) {
        if (written >= out_len_)
            return -1;
        out_[written] = '=';
    }

    if (written >= out_len_)
        return -1; //  no room for the terminating NUL
    out_[written] = '\0';
    return written;
}
1045
//  Computes SHA-1 over the NUL-terminated key_ followed by the fixed RFC 6455
//  GUID, writing the SHA_DIGEST_LENGTH-byte digest to hash_ (which must have
//  room for it). key_ is only read. Base64-encoding hash_ yields the
//  Sec-WebSocket-Accept header value. The SHA-1 backend (NSS, the bundled
//  implementation, or GnuTLS) is selected at build time; building without any
//  of them is a compile error.
static void compute_accept_key (char *key_, unsigned char *hash_)
{
    //  Fixed GUID from RFC 6455 that is appended to the client's key.
    const char *magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
#ifdef ZMQ_USE_NSS
    unsigned int len;
    HASH_HashType type = HASH_GetHashTypeByOidTag (SEC_OID_SHA1);
    HASHContext *ctx = HASH_Create (type);
    assert (ctx);

    HASH_Begin (ctx);
    HASH_Update (ctx, (unsigned char *) key_, (unsigned int) strlen (key_));
    HASH_Update (ctx, (unsigned char *) magic_string,
                 (unsigned int) strlen (magic_string));
    //  len receives the digest size actually written (at most
    //  SHA_DIGEST_LENGTH); it is not checked here.
    HASH_End (ctx, hash_, &len, SHA_DIGEST_LENGTH);
    HASH_Destroy (ctx);
#elif defined ZMQ_USE_BUILTIN_SHA1
    sha1_ctxt ctx;
    SHA1_Init (&ctx);
    SHA1_Update (&ctx, (unsigned char *) key_, strlen (key_));
    SHA1_Update (&ctx, (unsigned char *) magic_string, strlen (magic_string));

    SHA1_Final (hash_, &ctx);
#elif defined ZMQ_USE_GNUTLS
    //  NOTE(review): the gnutls_hash_init / gnutls_hash return codes are not
    //  checked; a failed init would leave hd unusable - consider asserting.
    gnutls_hash_hd_t hd;
    gnutls_hash_init (&hd, GNUTLS_DIG_SHA1);
    gnutls_hash (hd, key_, strlen (key_));
    gnutls_hash (hd, magic_string, strlen (magic_string));
    //  Writes the final digest into hash_ and releases hd.
    gnutls_hash_deinit (hd, hash_);
#else
#error "No sha1 implementation set"
#endif
}
1078