1 /*
2 Copyright (c) 2007-2019 Contributors as noted in the AUTHORS file
3 
4 This file is part of libzmq, the ZeroMQ core engine in C++.
5 
6 libzmq is free software; you can redistribute it and/or modify it under
7 the terms of the GNU Lesser General Public License (LGPL) as published
8 by the Free Software Foundation; either version 3 of the License, or
9 (at your option) any later version.
10 
11 As a special exception, the Contributors give you permission to link
12 this library with independent modules to produce an executable,
13 regardless of the license terms of these independent modules, and to
14 copy and distribute the resulting executable under terms of your choice,
15 provided that you also meet, for each linked independent module, the
16 terms and conditions of the license of that module. An independent
17 module is a module which is not derived from or based on this library.
18 If you modify this library, you must extend this exception to your
19 version of the library.
20 
21 libzmq is distributed in the hope that it will be useful, but WITHOUT
22 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
23 FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
24 License for more details.
25 
26 You should have received a copy of the GNU Lesser General Public License
27 along with this program.  If not, see <http://www.gnu.org/licenses/>.
28 */
29 
30 #include "precompiled.hpp"
31 
32 #ifdef ZMQ_USE_NSS
33 #include <secoid.h>
34 #include <sechash.h>
35 #define SHA_DIGEST_LENGTH 20
36 #elif defined ZMQ_USE_BUILTIN_SHA1
37 #include "../external/sha1/sha1.h"
38 #elif defined ZMQ_USE_GNUTLS
39 #define SHA_DIGEST_LENGTH 20
40 #include <gnutls/gnutls.h>
41 #include <gnutls/crypto.h>
42 #endif
43 
44 #if !defined ZMQ_HAVE_WINDOWS
45 #include <sys/types.h>
46 #include <unistd.h>
47 #include <sys/socket.h>
48 #include <netinet/in.h>
49 #include <arpa/inet.h>
50 #ifdef ZMQ_HAVE_VXWORKS
51 #include <sockLib.h>
52 #endif
53 #endif
54 
55 #include <cstring>
56 
57 #include "compat.hpp"
58 #include "tcp.hpp"
59 #include "ws_engine.hpp"
60 #include "session_base.hpp"
61 #include "err.hpp"
62 #include "ip.hpp"
63 #include "random.hpp"
64 #include "ws_decoder.hpp"
65 #include "ws_encoder.hpp"
66 #include "null_mechanism.hpp"
67 #include "plain_server.hpp"
68 #include "plain_client.hpp"
69 
70 #ifdef ZMQ_HAVE_CURVE
71 #include "curve_client.hpp"
72 #include "curve_server.hpp"
73 #endif
74 
75 //  OSX uses a different name for this socket option
76 #ifndef IPV6_ADD_MEMBERSHIP
77 #define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP
78 #endif
79 
80 #ifdef __APPLE__
81 #include <TargetConditionals.h>
82 #endif
83 
//  Encodes in_len_ bytes from in_ into the character buffer out_, whose
//  capacity is out_len_. Presumably base64, given the name and its use for
//  the Sec-WebSocket-Key / Sec-WebSocket-Accept headers. Callers in this file
//  treat a positive result as the encoded length. TODO confirm the exact
//  failure and NUL-termination contract at the definition.
static int
encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_);

//  Derives the Sec-WebSocket-Accept digest for the client-supplied key_ into
//  hash_ (SHA_DIGEST_LENGTH bytes). Presumably this is SHA-1 over the key
//  plus the RFC 6455 GUID; confirm at the definition.
static void compute_accept_key (char *key_,
                                unsigned char hash_[SHA_DIGEST_LENGTH]);
89 
ws_engine_t(fd_t fd_,const options_t & options_,const endpoint_uri_pair_t & endpoint_uri_pair_,const ws_address_t & address_,bool client_)90 zmq::ws_engine_t::ws_engine_t (fd_t fd_,
91                                const options_t &options_,
92                                const endpoint_uri_pair_t &endpoint_uri_pair_,
93                                const ws_address_t &address_,
94                                bool client_) :
95     stream_engine_base_t (fd_, options_, endpoint_uri_pair_, true),
96     _client (client_),
97     _address (address_),
98     _client_handshake_state (client_handshake_initial),
99     _server_handshake_state (handshake_initial),
100     _header_name_position (0),
101     _header_value_position (0),
102     _header_upgrade_websocket (false),
103     _header_connection_upgrade (false),
104     _heartbeat_timeout (0)
105 {
106     memset (_websocket_key, 0, MAX_HEADER_VALUE_LENGTH + 1);
107     memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1);
108     memset (_websocket_protocol, 0, 256);
109 
110     _next_msg = &ws_engine_t::next_handshake_command;
111     _process_msg = &ws_engine_t::process_handshake_command;
112     _close_msg.init ();
113 
114     if (_options.heartbeat_interval > 0) {
115         _heartbeat_timeout = _options.heartbeat_timeout;
116         if (_heartbeat_timeout == -1)
117             _heartbeat_timeout = _options.heartbeat_interval;
118     }
119 }
120 
~ws_engine_t()121 zmq::ws_engine_t::~ws_engine_t ()
122 {
123     _close_msg.close ();
124 }
125 
start_ws_handshake()126 void zmq::ws_engine_t::start_ws_handshake ()
127 {
128     if (_client) {
129         const char *protocol;
130         if (_options.mechanism == ZMQ_NULL)
131             protocol = "ZWS2.0/NULL,ZWS2.0";
132         else if (_options.mechanism == ZMQ_PLAIN)
133             protocol = "ZWS2.0/PLAIN";
134 #ifdef ZMQ_HAVE_CURVE
135         else if (_options.mechanism == ZMQ_CURVE)
136             protocol = "ZWS2.0/CURVE";
137 #endif
138         else {
139             // Avoid unitialized variable error breaking UWP build
140             protocol = "";
141             assert (false);
142         }
143 
144         unsigned char nonce[16];
145         int *p = reinterpret_cast<int *> (nonce);
146 
147         // The nonce doesn't have to be secure one, it is just use to avoid proxy cache
148         *p = zmq::generate_random ();
149         *(p + 1) = zmq::generate_random ();
150         *(p + 2) = zmq::generate_random ();
151         *(p + 3) = zmq::generate_random ();
152 
153         int size =
154           encode_base64 (nonce, 16, _websocket_key, MAX_HEADER_VALUE_LENGTH);
155         assert (size > 0);
156 
157         size = snprintf (
158           reinterpret_cast<char *> (_write_buffer), WS_BUFFER_SIZE,
159           "GET %s HTTP/1.1\r\n"
160           "Host: %s\r\n"
161           "Upgrade: websocket\r\n"
162           "Connection: Upgrade\r\n"
163           "Sec-WebSocket-Key: %s\r\n"
164           "Sec-WebSocket-Protocol: %s\r\n"
165           "Sec-WebSocket-Version: 13\r\n\r\n",
166           _address.path (), _address.host (), _websocket_key, protocol);
167         assert (size > 0 && size < WS_BUFFER_SIZE);
168         _outpos = _write_buffer;
169         _outsize = size;
170         set_pollout ();
171     }
172 }
173 
//  Engine start-up hook. It queues the client-side upgrade request (this is
//  a no-op on the server side) and enables POLLIN. It then calls in_event ()
//  directly, presumably to process any input that arrived before polling
//  was armed. The order of these three calls matters.
void zmq::ws_engine_t::plug_internal ()
{
    start_ws_handshake ();
    set_pollin ();
    in_event ();
}
180 
routing_id_msg(msg_t * msg_)181 int zmq::ws_engine_t::routing_id_msg (msg_t *msg_)
182 {
183     const int rc = msg_->init_size (_options.routing_id_size);
184     errno_assert (rc == 0);
185     if (_options.routing_id_size > 0)
186         memcpy (msg_->data (), _options.routing_id, _options.routing_id_size);
187     _next_msg = &ws_engine_t::pull_msg_from_session;
188 
189     return 0;
190 }
191 
process_routing_id_msg(msg_t * msg_)192 int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_)
193 {
194     if (_options.recv_routing_id) {
195         msg_->set_flags (msg_t::routing_id);
196         const int rc = session ()->push_msg (msg_);
197         errno_assert (rc == 0);
198     } else {
199         int rc = msg_->close ();
200         errno_assert (rc == 0);
201         rc = msg_->init ();
202         errno_assert (rc == 0);
203     }
204 
205     _process_msg = &ws_engine_t::push_msg_to_session;
206 
207     return 0;
208 }
209 
select_protocol(const char * protocol_)210 bool zmq::ws_engine_t::select_protocol (const char *protocol_)
211 {
212     if (_options.mechanism == ZMQ_NULL && (strcmp ("ZWS2.0", protocol_) == 0)) {
213         _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
214           &ws_engine_t::routing_id_msg);
215         _process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
216           &ws_engine_t::process_routing_id_msg);
217 
218         // No mechanism in place, enabling heartbeat
219         if (_options.heartbeat_interval > 0 && !_has_heartbeat_timer) {
220             add_timer (_options.heartbeat_interval, heartbeat_ivl_timer_id);
221             _has_heartbeat_timer = true;
222         }
223 
224         return true;
225     }
226     if (_options.mechanism == ZMQ_NULL
227         && strcmp ("ZWS2.0/NULL", protocol_) == 0) {
228         _mechanism = new (std::nothrow)
229           null_mechanism_t (session (), _peer_address, _options);
230         alloc_assert (_mechanism);
231         return true;
232     } else if (_options.mechanism == ZMQ_PLAIN
233                && strcmp ("ZWS2.0/PLAIN", protocol_) == 0) {
234         if (_options.as_server)
235             _mechanism = new (std::nothrow)
236               plain_server_t (session (), _peer_address, _options);
237         else
238             _mechanism =
239               new (std::nothrow) plain_client_t (session (), _options);
240         alloc_assert (_mechanism);
241         return true;
242     }
243 #ifdef ZMQ_HAVE_CURVE
244     else if (_options.mechanism == ZMQ_CURVE
245              && strcmp ("ZWS2.0/CURVE", protocol_) == 0) {
246         if (_options.as_server)
247             _mechanism = new (std::nothrow)
248               curve_server_t (session (), _peer_address, _options, false);
249         else
250             _mechanism =
251               new (std::nothrow) curve_client_t (session (), _options, false);
252         alloc_assert (_mechanism);
253         return true;
254     }
255 #endif
256 
257     return false;
258 }
259 
handshake()260 bool zmq::ws_engine_t::handshake ()
261 {
262     bool complete;
263 
264     if (_client)
265         complete = client_handshake ();
266     else
267         complete = server_handshake ();
268 
269     if (complete) {
270         _encoder =
271           new (std::nothrow) ws_encoder_t (_options.out_batch_size, _client);
272         alloc_assert (_encoder);
273 
274         _decoder = new (std::nothrow)
275           ws_decoder_t (_options.in_batch_size, _options.maxmsgsize,
276                         _options.zero_copy, !_client);
277         alloc_assert (_decoder);
278 
279         socket ()->event_handshake_succeeded (_endpoint_uri_pair, 0);
280 
281         set_pollout ();
282     }
283 
284     return complete;
285 }
286 
server_handshake()287 bool zmq::ws_engine_t::server_handshake ()
288 {
289     const int nbytes = read (_read_buffer, WS_BUFFER_SIZE);
290     if (nbytes == -1) {
291         if (errno != EAGAIN)
292             error (zmq::i_engine::connection_error);
293         return false;
294     }
295 
296     _inpos = _read_buffer;
297     _insize = nbytes;
298 
299     while (_insize > 0) {
300         const char c = static_cast<char> (*_inpos);
301 
302         switch (_server_handshake_state) {
303             case handshake_initial:
304                 if (c == 'G')
305                     _server_handshake_state = request_line_G;
306                 else
307                     _server_handshake_state = handshake_error;
308                 break;
309             case request_line_G:
310                 if (c == 'E')
311                     _server_handshake_state = request_line_GE;
312                 else
313                     _server_handshake_state = handshake_error;
314                 break;
315             case request_line_GE:
316                 if (c == 'T')
317                     _server_handshake_state = request_line_GET;
318                 else
319                     _server_handshake_state = handshake_error;
320                 break;
321             case request_line_GET:
322                 if (c == ' ')
323                     _server_handshake_state = request_line_GET_space;
324                 else
325                     _server_handshake_state = handshake_error;
326                 break;
327             case request_line_GET_space:
328                 if (c == '\r' || c == '\n')
329                     _server_handshake_state = handshake_error;
330                 // TODO: instead of check what is not allowed check what is allowed
331                 if (c != ' ')
332                     _server_handshake_state = request_line_resource;
333                 else
334                     _server_handshake_state = request_line_GET_space;
335                 break;
336             case request_line_resource:
337                 if (c == '\r' || c == '\n')
338                     _server_handshake_state = handshake_error;
339                 else if (c == ' ')
340                     _server_handshake_state = request_line_resource_space;
341                 else
342                     _server_handshake_state = request_line_resource;
343                 break;
344             case request_line_resource_space:
345                 if (c == 'H')
346                     _server_handshake_state = request_line_H;
347                 else
348                     _server_handshake_state = handshake_error;
349                 break;
350             case request_line_H:
351                 if (c == 'T')
352                     _server_handshake_state = request_line_HT;
353                 else
354                     _server_handshake_state = handshake_error;
355                 break;
356             case request_line_HT:
357                 if (c == 'T')
358                     _server_handshake_state = request_line_HTT;
359                 else
360                     _server_handshake_state = handshake_error;
361                 break;
362             case request_line_HTT:
363                 if (c == 'P')
364                     _server_handshake_state = request_line_HTTP;
365                 else
366                     _server_handshake_state = handshake_error;
367                 break;
368             case request_line_HTTP:
369                 if (c == '/')
370                     _server_handshake_state = request_line_HTTP_slash;
371                 else
372                     _server_handshake_state = handshake_error;
373                 break;
374             case request_line_HTTP_slash:
375                 if (c == '1')
376                     _server_handshake_state = request_line_HTTP_slash_1;
377                 else
378                     _server_handshake_state = handshake_error;
379                 break;
380             case request_line_HTTP_slash_1:
381                 if (c == '.')
382                     _server_handshake_state = request_line_HTTP_slash_1_dot;
383                 else
384                     _server_handshake_state = handshake_error;
385                 break;
386             case request_line_HTTP_slash_1_dot:
387                 if (c == '1')
388                     _server_handshake_state = request_line_HTTP_slash_1_dot_1;
389                 else
390                     _server_handshake_state = handshake_error;
391                 break;
392             case request_line_HTTP_slash_1_dot_1:
393                 if (c == '\r')
394                     _server_handshake_state = request_line_cr;
395                 else
396                     _server_handshake_state = handshake_error;
397                 break;
398             case request_line_cr:
399                 if (c == '\n')
400                     _server_handshake_state = header_field_begin_name;
401                 else
402                     _server_handshake_state = handshake_error;
403                 break;
404             case header_field_begin_name:
405                 switch (c) {
406                     case '\r':
407                         _server_handshake_state = handshake_end_line_cr;
408                         break;
409                     case '\n':
410                         _server_handshake_state = handshake_error;
411                         break;
412                     default:
413                         _header_name[0] = c;
414                         _header_name_position = 1;
415                         _server_handshake_state = header_field_name;
416                         break;
417                 }
418                 break;
419             case header_field_name:
420                 if (c == '\r' || c == '\n')
421                     _server_handshake_state = handshake_error;
422                 else if (c == ':') {
423                     _header_name[_header_name_position] = '\0';
424                     _server_handshake_state = header_field_colon;
425                 } else if (_header_name_position + 1 > MAX_HEADER_NAME_LENGTH)
426                     _server_handshake_state = handshake_error;
427                 else {
428                     _header_name[_header_name_position] = c;
429                     _header_name_position++;
430                     _server_handshake_state = header_field_name;
431                 }
432                 break;
433             case header_field_colon:
434             case header_field_value_trailing_space:
435                 if (c == '\n')
436                     _server_handshake_state = handshake_error;
437                 else if (c == '\r')
438                     _server_handshake_state = header_field_cr;
439                 else if (c == ' ')
440                     _server_handshake_state = header_field_value_trailing_space;
441                 else {
442                     _header_value[0] = c;
443                     _header_value_position = 1;
444                     _server_handshake_state = header_field_value;
445                 }
446                 break;
447             case header_field_value:
448                 if (c == '\n')
449                     _server_handshake_state = handshake_error;
450                 else if (c == '\r') {
451                     _header_value[_header_value_position] = '\0';
452 
453                     if (strcasecmp ("upgrade", _header_name) == 0)
454                         _header_upgrade_websocket =
455                           strcasecmp ("websocket", _header_value) == 0;
456                     else if (strcasecmp ("connection", _header_name) == 0)
457                         _header_connection_upgrade =
458                           strcasecmp ("upgrade", _header_value) == 0;
459                     else if (strcasecmp ("Sec-WebSocket-Key", _header_name)
460                              == 0)
461                         strcpy_s (_websocket_key, _header_value);
462                     else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
463                              == 0) {
464                         // Currently only the ZWS2.0 is supported
465                         // Sec-WebSocket-Protocol can appear multiple times or be a comma separated list
466                         // if _websocket_protocol is already set we skip the check
467                         if (_websocket_protocol[0] == '\0') {
468                             char *rest = 0;
469                             char *p = strtok_r (_header_value, ",", &rest);
470                             while (p != NULL) {
471                                 if (*p == ' ')
472                                     p++;
473 
474                                 if (select_protocol (p)) {
475                                     strcpy_s (_websocket_protocol, p);
476                                     break;
477                                 }
478 
479                                 p = strtok_r (NULL, ",", &rest);
480                             }
481                         }
482                     }
483 
484                     _server_handshake_state = header_field_cr;
485                 } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)
486                     _server_handshake_state = handshake_error;
487                 else {
488                     _header_value[_header_value_position] = c;
489                     _header_value_position++;
490                     _server_handshake_state = header_field_value;
491                 }
492                 break;
493             case header_field_cr:
494                 if (c == '\n')
495                     _server_handshake_state = header_field_begin_name;
496                 else
497                     _server_handshake_state = handshake_error;
498                 break;
499             case handshake_end_line_cr:
500                 if (c == '\n') {
501                     if (_header_connection_upgrade && _header_upgrade_websocket
502                         && _websocket_protocol[0] != '\0'
503                         && _websocket_key[0] != '\0') {
504                         _server_handshake_state = handshake_complete;
505 
506                         unsigned char hash[SHA_DIGEST_LENGTH];
507                         compute_accept_key (_websocket_key, hash);
508 
509                         const int accept_key_len = encode_base64 (
510                           hash, SHA_DIGEST_LENGTH, _websocket_accept,
511                           MAX_HEADER_VALUE_LENGTH);
512                         assert (accept_key_len > 0);
513                         _websocket_accept[accept_key_len] = '\0';
514 
515                         const int written =
516                           snprintf (reinterpret_cast<char *> (_write_buffer),
517                                     WS_BUFFER_SIZE,
518                                     "HTTP/1.1 101 Switching Protocols\r\n"
519                                     "Upgrade: websocket\r\n"
520                                     "Connection: Upgrade\r\n"
521                                     "Sec-WebSocket-Accept: %s\r\n"
522                                     "Sec-WebSocket-Protocol: %s\r\n"
523                                     "\r\n",
524                                     _websocket_accept, _websocket_protocol);
525                         assert (written >= 0 && written < WS_BUFFER_SIZE);
526                         _outpos = _write_buffer;
527                         _outsize = written;
528 
529                         _inpos++;
530                         _insize--;
531 
532                         return true;
533                     }
534                     _server_handshake_state = handshake_error;
535                 } else
536                     _server_handshake_state = handshake_error;
537                 break;
538             default:
539                 assert (false);
540         }
541 
542         _inpos++;
543         _insize--;
544 
545         if (_server_handshake_state == handshake_error) {
546             // TODO: send bad request
547 
548             socket ()->event_handshake_failed_protocol (
549               _endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
550 
551             error (zmq::i_engine::protocol_error);
552             return false;
553         }
554     }
555     return false;
556 }
557 
client_handshake()558 bool zmq::ws_engine_t::client_handshake ()
559 {
560     const int nbytes = read (_read_buffer, WS_BUFFER_SIZE);
561     if (nbytes == -1) {
562         if (errno != EAGAIN)
563             error (zmq::i_engine::connection_error);
564         return false;
565     }
566 
567     _inpos = _read_buffer;
568     _insize = nbytes;
569 
570     while (_insize > 0) {
571         const char c = static_cast<char> (*_inpos);
572 
573         switch (_client_handshake_state) {
574             case client_handshake_initial:
575                 if (c == 'H')
576                     _client_handshake_state = response_line_H;
577                 else
578                     _client_handshake_state = client_handshake_error;
579                 break;
580             case response_line_H:
581                 if (c == 'T')
582                     _client_handshake_state = response_line_HT;
583                 else
584                     _client_handshake_state = client_handshake_error;
585                 break;
586             case response_line_HT:
587                 if (c == 'T')
588                     _client_handshake_state = response_line_HTT;
589                 else
590                     _client_handshake_state = client_handshake_error;
591                 break;
592             case response_line_HTT:
593                 if (c == 'P')
594                     _client_handshake_state = response_line_HTTP;
595                 else
596                     _client_handshake_state = client_handshake_error;
597                 break;
598             case response_line_HTTP:
599                 if (c == '/')
600                     _client_handshake_state = response_line_HTTP_slash;
601                 else
602                     _client_handshake_state = client_handshake_error;
603                 break;
604             case response_line_HTTP_slash:
605                 if (c == '1')
606                     _client_handshake_state = response_line_HTTP_slash_1;
607                 else
608                     _client_handshake_state = client_handshake_error;
609                 break;
610             case response_line_HTTP_slash_1:
611                 if (c == '.')
612                     _client_handshake_state = response_line_HTTP_slash_1_dot;
613                 else
614                     _client_handshake_state = client_handshake_error;
615                 break;
616             case response_line_HTTP_slash_1_dot:
617                 if (c == '1')
618                     _client_handshake_state = response_line_HTTP_slash_1_dot_1;
619                 else
620                     _client_handshake_state = client_handshake_error;
621                 break;
622             case response_line_HTTP_slash_1_dot_1:
623                 if (c == ' ')
624                     _client_handshake_state =
625                       response_line_HTTP_slash_1_dot_1_space;
626                 else
627                     _client_handshake_state = client_handshake_error;
628                 break;
629             case response_line_HTTP_slash_1_dot_1_space:
630                 if (c == ' ')
631                     _client_handshake_state =
632                       response_line_HTTP_slash_1_dot_1_space;
633                 else if (c == '1')
634                     _client_handshake_state = response_line_status_1;
635                 else
636                     _client_handshake_state = client_handshake_error;
637                 break;
638             case response_line_status_1:
639                 if (c == '0')
640                     _client_handshake_state = response_line_status_10;
641                 else
642                     _client_handshake_state = client_handshake_error;
643                 break;
644             case response_line_status_10:
645                 if (c == '1')
646                     _client_handshake_state = response_line_status_101;
647                 else
648                     _client_handshake_state = client_handshake_error;
649                 break;
650             case response_line_status_101:
651                 if (c == ' ')
652                     _client_handshake_state = response_line_status_101_space;
653                 else
654                     _client_handshake_state = client_handshake_error;
655                 break;
656             case response_line_status_101_space:
657                 if (c == ' ')
658                     _client_handshake_state = response_line_status_101_space;
659                 else if (c == 'S')
660                     _client_handshake_state = response_line_s;
661                 else
662                     _client_handshake_state = client_handshake_error;
663                 break;
664             case response_line_s:
665                 if (c == 'w')
666                     _client_handshake_state = response_line_sw;
667                 else
668                     _client_handshake_state = client_handshake_error;
669                 break;
670             case response_line_sw:
671                 if (c == 'i')
672                     _client_handshake_state = response_line_swi;
673                 else
674                     _client_handshake_state = client_handshake_error;
675                 break;
676             case response_line_swi:
677                 if (c == 't')
678                     _client_handshake_state = response_line_swit;
679                 else
680                     _client_handshake_state = client_handshake_error;
681                 break;
682             case response_line_swit:
683                 if (c == 'c')
684                     _client_handshake_state = response_line_switc;
685                 else
686                     _client_handshake_state = client_handshake_error;
687                 break;
688             case response_line_switc:
689                 if (c == 'h')
690                     _client_handshake_state = response_line_switch;
691                 else
692                     _client_handshake_state = client_handshake_error;
693                 break;
694             case response_line_switch:
695                 if (c == 'i')
696                     _client_handshake_state = response_line_switchi;
697                 else
698                     _client_handshake_state = client_handshake_error;
699                 break;
700             case response_line_switchi:
701                 if (c == 'n')
702                     _client_handshake_state = response_line_switchin;
703                 else
704                     _client_handshake_state = client_handshake_error;
705                 break;
706             case response_line_switchin:
707                 if (c == 'g')
708                     _client_handshake_state = response_line_switching;
709                 else
710                     _client_handshake_state = client_handshake_error;
711                 break;
712             case response_line_switching:
713                 if (c == ' ')
714                     _client_handshake_state = response_line_switching_space;
715                 else
716                     _client_handshake_state = client_handshake_error;
717                 break;
718             case response_line_switching_space:
719                 if (c == 'P')
720                     _client_handshake_state = response_line_p;
721                 else
722                     _client_handshake_state = client_handshake_error;
723                 break;
724             case response_line_p:
725                 if (c == 'r')
726                     _client_handshake_state = response_line_pr;
727                 else
728                     _client_handshake_state = client_handshake_error;
729                 break;
730             case response_line_pr:
731                 if (c == 'o')
732                     _client_handshake_state = response_line_pro;
733                 else
734                     _client_handshake_state = client_handshake_error;
735                 break;
736             case response_line_pro:
737                 if (c == 't')
738                     _client_handshake_state = response_line_prot;
739                 else
740                     _client_handshake_state = client_handshake_error;
741                 break;
742             case response_line_prot:
743                 if (c == 'o')
744                     _client_handshake_state = response_line_proto;
745                 else
746                     _client_handshake_state = client_handshake_error;
747                 break;
748             case response_line_proto:
749                 if (c == 'c')
750                     _client_handshake_state = response_line_protoc;
751                 else
752                     _client_handshake_state = client_handshake_error;
753                 break;
754             case response_line_protoc:
755                 if (c == 'o')
756                     _client_handshake_state = response_line_protoco;
757                 else
758                     _client_handshake_state = client_handshake_error;
759                 break;
760             case response_line_protoco:
761                 if (c == 'l')
762                     _client_handshake_state = response_line_protocol;
763                 else
764                     _client_handshake_state = client_handshake_error;
765                 break;
766             case response_line_protocol:
767                 if (c == 's')
768                     _client_handshake_state = response_line_protocols;
769                 else
770                     _client_handshake_state = client_handshake_error;
771                 break;
772             case response_line_protocols:
773                 if (c == '\r')
774                     _client_handshake_state = response_line_cr;
775                 else
776                     _client_handshake_state = client_handshake_error;
777                 break;
778             case response_line_cr:
779                 if (c == '\n')
780                     _client_handshake_state = client_header_field_begin_name;
781                 else
782                     _client_handshake_state = client_handshake_error;
783                 break;
784             case client_header_field_begin_name:
785                 switch (c) {
786                     case '\r':
787                         _client_handshake_state = client_handshake_end_line_cr;
788                         break;
789                     case '\n':
790                         _client_handshake_state = client_handshake_error;
791                         break;
792                     default:
793                         _header_name[0] = c;
794                         _header_name_position = 1;
795                         _client_handshake_state = client_header_field_name;
796                         break;
797                 }
798                 break;
799             case client_header_field_name:
800                 if (c == '\r' || c == '\n')
801                     _client_handshake_state = client_handshake_error;
802                 else if (c == ':') {
803                     _header_name[_header_name_position] = '\0';
804                     _client_handshake_state = client_header_field_colon;
805                 } else if (_header_name_position + 1 > MAX_HEADER_NAME_LENGTH)
806                     _client_handshake_state = client_handshake_error;
807                 else {
808                     _header_name[_header_name_position] = c;
809                     _header_name_position++;
810                     _client_handshake_state = client_header_field_name;
811                 }
812                 break;
813             case client_header_field_colon:
814             case client_header_field_value_trailing_space:
815                 if (c == '\n')
816                     _client_handshake_state = client_handshake_error;
817                 else if (c == '\r')
818                     _client_handshake_state = client_header_field_cr;
819                 else if (c == ' ')
820                     _client_handshake_state =
821                       client_header_field_value_trailing_space;
822                 else {
823                     _header_value[0] = c;
824                     _header_value_position = 1;
825                     _client_handshake_state = client_header_field_value;
826                 }
827                 break;
828             case client_header_field_value:
829                 if (c == '\n')
830                     _client_handshake_state = client_handshake_error;
831                 else if (c == '\r') {
832                     _header_value[_header_value_position] = '\0';
833 
834                     if (strcasecmp ("upgrade", _header_name) == 0)
835                         _header_upgrade_websocket =
836                           strcasecmp ("websocket", _header_value) == 0;
837                     else if (strcasecmp ("connection", _header_name) == 0)
838                         _header_connection_upgrade =
839                           strcasecmp ("upgrade", _header_value) == 0;
840                     else if (strcasecmp ("Sec-WebSocket-Accept", _header_name)
841                              == 0)
842                         strcpy_s (_websocket_accept, _header_value);
843                     else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
844                              == 0) {
845                         if (_mechanism) {
846                             _client_handshake_state = client_handshake_error;
847                             break;
848                         }
849                         if (select_protocol (_header_value))
850                             strcpy_s (_websocket_protocol, _header_value);
851                     }
852                     _client_handshake_state = client_header_field_cr;
853                 } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)
854                     _client_handshake_state = client_handshake_error;
855                 else {
856                     _header_value[_header_value_position] = c;
857                     _header_value_position++;
858                     _client_handshake_state = client_header_field_value;
859                 }
860                 break;
861             case client_header_field_cr:
862                 if (c == '\n')
863                     _client_handshake_state = client_header_field_begin_name;
864                 else
865                     _client_handshake_state = client_handshake_error;
866                 break;
867             case client_handshake_end_line_cr:
868                 if (c == '\n') {
869                     if (_header_connection_upgrade && _header_upgrade_websocket
870                         && _websocket_protocol[0] != '\0'
871                         && _websocket_accept[0] != '\0') {
872                         _client_handshake_state = client_handshake_complete;
873 
874                         // TODO: validate accept key
875 
876                         _inpos++;
877                         _insize--;
878 
879                         return true;
880                     }
881                     _client_handshake_state = client_handshake_error;
882                 } else
883                     _client_handshake_state = client_handshake_error;
884                 break;
885             default:
886                 assert (false);
887         }
888 
889         _inpos++;
890         _insize--;
891 
892         if (_client_handshake_state == client_handshake_error) {
893             socket ()->event_handshake_failed_protocol (
894               _endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
895 
896             error (zmq::i_engine::protocol_error);
897             return false;
898         }
899     }
900 
901     return false;
902 }
903 
decode_and_push(msg_t * msg_)904 int zmq::ws_engine_t::decode_and_push (msg_t *msg_)
905 {
906     zmq_assert (_mechanism != NULL);
907 
908     //  with WS engine, ping and pong commands are control messages and should not go through any mechanism
909     if (msg_->is_ping () || msg_->is_pong () || msg_->is_close_cmd ()) {
910         if (process_command_message (msg_) == -1)
911             return -1;
912     } else if (_mechanism->decode (msg_) == -1)
913         return -1;
914 
915     if (_has_timeout_timer) {
916         _has_timeout_timer = false;
917         cancel_timer (heartbeat_timeout_timer_id);
918     }
919 
920     if (msg_->flags () & msg_t::command && !msg_->is_ping ()
921         && !msg_->is_pong () && !msg_->is_close_cmd ())
922         process_command_message (msg_);
923 
924     if (_metadata)
925         msg_->set_metadata (_metadata);
926     if (session ()->push_msg (msg_) == -1) {
927         if (errno == EAGAIN)
928             _process_msg = &ws_engine_t::push_one_then_decode_and_push;
929         return -1;
930     }
931     return 0;
932 }
933 
produce_close_message(msg_t * msg_)934 int zmq::ws_engine_t::produce_close_message (msg_t *msg_)
935 {
936     int rc = msg_->move (_close_msg);
937     errno_assert (rc == 0);
938 
939     _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
940       &ws_engine_t::produce_no_msg_after_close);
941 
942     return rc;
943 }
944 
produce_no_msg_after_close(msg_t * msg_)945 int zmq::ws_engine_t::produce_no_msg_after_close (msg_t *msg_)
946 {
947     LIBZMQ_UNUSED (msg_);
948     _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
949       &ws_engine_t::close_connection_after_close);
950 
951     errno = EAGAIN;
952     return -1;
953 }
954 
close_connection_after_close(msg_t * msg_)955 int zmq::ws_engine_t::close_connection_after_close (msg_t *msg_)
956 {
957     LIBZMQ_UNUSED (msg_);
958     error (connection_error);
959     errno = ECONNRESET;
960     return -1;
961 }
962 
produce_ping_message(msg_t * msg_)963 int zmq::ws_engine_t::produce_ping_message (msg_t *msg_)
964 {
965     int rc = msg_->init ();
966     errno_assert (rc == 0);
967     msg_->set_flags (msg_t::command | msg_t::ping);
968 
969     _next_msg = &ws_engine_t::pull_and_encode;
970     if (!_has_timeout_timer && _heartbeat_timeout > 0) {
971         add_timer (_heartbeat_timeout, heartbeat_timeout_timer_id);
972         _has_timeout_timer = true;
973     }
974 
975     return rc;
976 }
977 
978 
produce_pong_message(msg_t * msg_)979 int zmq::ws_engine_t::produce_pong_message (msg_t *msg_)
980 {
981     int rc = msg_->init ();
982     errno_assert (rc == 0);
983     msg_->set_flags (msg_t::command | msg_t::pong);
984 
985     _next_msg = &ws_engine_t::pull_and_encode;
986     return rc;
987 }
988 
989 
process_command_message(msg_t * msg_)990 int zmq::ws_engine_t::process_command_message (msg_t *msg_)
991 {
992     if (msg_->is_ping ()) {
993         _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
994           &ws_engine_t::produce_pong_message);
995         out_event ();
996     } else if (msg_->is_close_cmd ()) {
997         int rc = _close_msg.copy (*msg_);
998         errno_assert (rc == 0);
999         _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
1000           &ws_engine_t::produce_close_message);
1001         out_event ();
1002     }
1003 
1004     return 0;
1005 }
1006 
//  Encodes in_len_ bytes of in_ as standard base64 (A-Z a-z 0-9 + /,
//  '='-padded to a multiple of 4 characters) into out_, followed by a NUL.
//
//  Returns the number of characters written, excluding the NUL. Returns -1
//  if out_len_ cannot hold the padded text plus the terminator; any
//  characters already written before the overflow are left in out_.
static int
encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_)
{
    static const unsigned char alphabet[65] =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    int pos = 0;      //  next write index into out_
    uint32_t acc = 0; //  bit accumulator; only its low bits are meaningful
    int nbits = 0;    //  bits in acc not yet emitted

    //  Shift input in a byte at a time and drain every complete 6-bit group.
    for (int i = 0; i < in_len_; ++i) {
        acc = (acc << 8) | in_[i];
        for (nbits += 8; nbits >= 6;) {
            nbits -= 6;
            if (pos >= out_len_)
                return -1; /* truncation is failure */
            out_[pos++] = alphabet[(acc >> nbits) & 63];
        }
    }

    //  Left-align the 2 or 4 leftover bits into one final sextet.
    if (nbits != 0) {
        if (pos >= out_len_)
            return -1; /* truncation is failure */
        acc <<= 6 - nbits;
        out_[pos++] = alphabet[acc & 63];
    }

    //  Pad with '=' up to a multiple of four characters.
    for (; (pos & 3) != 0; ++pos) {
        if (pos >= out_len_)
            return -1; /* truncation is failure */
        out_[pos] = '=';
    }

    if (pos >= out_len_)
        return -1; /* no room for null terminator */
    out_[pos] = '\0';
    return pos;
}
1045 
compute_accept_key(char * key_,unsigned char * hash_)1046 static void compute_accept_key (char *key_, unsigned char *hash_)
1047 {
1048     const char *magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1049 #ifdef ZMQ_USE_NSS
1050     unsigned int len;
1051     HASH_HashType type = HASH_GetHashTypeByOidTag (SEC_OID_SHA1);
1052     HASHContext *ctx = HASH_Create (type);
1053     assert (ctx);
1054 
1055     HASH_Begin (ctx);
1056     HASH_Update (ctx, (unsigned char *) key_, (unsigned int) strlen (key_));
1057     HASH_Update (ctx, (unsigned char *) magic_string,
1058                  (unsigned int) strlen (magic_string));
1059     HASH_End (ctx, hash_, &len, SHA_DIGEST_LENGTH);
1060     HASH_Destroy (ctx);
1061 #elif defined ZMQ_USE_BUILTIN_SHA1
1062     sha1_ctxt ctx;
1063     SHA1_Init (&ctx);
1064     SHA1_Update (&ctx, (unsigned char *) key_, strlen (key_));
1065     SHA1_Update (&ctx, (unsigned char *) magic_string, strlen (magic_string));
1066 
1067     SHA1_Final (hash_, &ctx);
1068 #elif defined ZMQ_USE_GNUTLS
1069     gnutls_hash_hd_t hd;
1070     gnutls_hash_init (&hd, GNUTLS_DIG_SHA1);
1071     gnutls_hash (hd, key_, strlen (key_));
1072     gnutls_hash (hd, magic_string, strlen (magic_string));
1073     gnutls_hash_deinit (hd, hash_);
1074 #else
1075 #error "No sha1 implementation set"
1076 #endif
1077 }
1078