1 /*
2     Copyright (c) 2007-2016 Contributors as noted in the AUTHORS file
3 
4     This file is part of libzmq, the ZeroMQ core engine in C++.
5 
6     libzmq is free software; you can redistribute it and/or modify it under
7     the terms of the GNU Lesser General Public License (LGPL) as published
8     by the Free Software Foundation; either version 3 of the License, or
9     (at your option) any later version.
10 
11     As a special exception, the Contributors give you permission to link
12     this library with independent modules to produce an executable,
13     regardless of the license terms of these independent modules, and to
14     copy and distribute the resulting executable under terms of your choice,
15     provided that you also meet, for each linked independent module, the
16     terms and conditions of the license of that module. An independent
17     module is a module which is not derived from or based on this library.
18     If you modify this library, you must extend this exception to your
19     version of the library.
20 
21     libzmq is distributed in the hope that it will be useful, but WITHOUT
22     ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
23     FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
24     License for more details.
25 
26     You should have received a copy of the GNU Lesser General Public License
27     along with this program.  If not, see <http://www.gnu.org/licenses/>.
28 */
29 
30 #include "precompiled.hpp"
31 #include "macros.hpp"
32 
33 #ifdef ZMQ_HAVE_CURVE
34 
35 #include "msg.hpp"
36 #include "session_base.hpp"
37 #include "err.hpp"
38 #include "curve_client.hpp"
39 #include "wire.hpp"
40 #include "curve_client_tools.hpp"
41 #include "secure_allocator.hpp"
42 
//  Creates the client side of the CURVE security mechanism for session_.
//
//  - The two string literals handed to curve_mechanism_base_t are presumably
//    the nonce prefixes for MESSAGE commands sent by the client ("...C") and
//    by the server ("...S") -- confirm against curve_mechanism_base_t.
//  - downgrade_sub_ is forwarded unchanged to curve_mechanism_base_t; its
//    meaning is defined there.
//  - The handshake starts in the send_hello state, so the first call to
//    next_handshake_command () produces a HELLO command.
//  - _tools is seeded with the configured CURVE public key, secret key and
//    server key taken from options_.
zmq::curve_client_t::curve_client_t (session_base_t *session_,
                                     const options_t &options_,
                                     const bool downgrade_sub_) :
    mechanism_base_t (session_, options_),
    curve_mechanism_base_t (session_,
                            options_,
                            "CurveZMQMESSAGEC",
                            "CurveZMQMESSAGES",
                            downgrade_sub_),
    _state (send_hello),
    _tools (options_.curve_public_key,
            options_.curve_secret_key,
            options_.curve_server_key)
{
}
58 
//  Destroys the mechanism. The body is intentionally empty: no explicit
//  cleanup is performed here beyond what member and base-class destructors
//  do on their own.
zmq::curve_client_t::~curve_client_t ()
{
}
62 
next_handshake_command(msg_t * msg_)63 int zmq::curve_client_t::next_handshake_command (msg_t *msg_)
64 {
65     int rc = 0;
66 
67     switch (_state) {
68         case send_hello:
69             rc = produce_hello (msg_);
70             if (rc == 0)
71                 _state = expect_welcome;
72             break;
73         case send_initiate:
74             rc = produce_initiate (msg_);
75             if (rc == 0)
76                 _state = expect_ready;
77             break;
78         default:
79             errno = EAGAIN;
80             rc = -1;
81     }
82     return rc;
83 }
84 
process_handshake_command(msg_t * msg_)85 int zmq::curve_client_t::process_handshake_command (msg_t *msg_)
86 {
87     const unsigned char *msg_data =
88       static_cast<unsigned char *> (msg_->data ());
89     const size_t msg_size = msg_->size ();
90 
91     int rc = 0;
92     if (curve_client_tools_t::is_handshake_command_welcome (msg_data, msg_size))
93         rc = process_welcome (msg_data, msg_size);
94     else if (curve_client_tools_t::is_handshake_command_ready (msg_data,
95                                                                msg_size))
96         rc = process_ready (msg_data, msg_size);
97     else if (curve_client_tools_t::is_handshake_command_error (msg_data,
98                                                                msg_size))
99         rc = process_error (msg_data, msg_size);
100     else {
101         session->get_socket ()->event_handshake_failed_protocol (
102           session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
103         errno = EPROTO;
104         rc = -1;
105     }
106 
107     if (rc == 0) {
108         rc = msg_->close ();
109         errno_assert (rc == 0);
110         rc = msg_->init ();
111         errno_assert (rc == 0);
112     }
113 
114     return rc;
115 }
116 
//  Encodes an outgoing application message by delegating to
//  curve_mechanism_base_t::encode () (presumably encryption into a CURVE
//  MESSAGE command -- see the base class). This is only legal once the
//  handshake has reached the connected state; anything else is treated as a
//  programming error and trips the assertion.
int zmq::curve_client_t::encode (msg_t *msg_)
{
    zmq_assert (_state == connected);
    return curve_mechanism_base_t::encode (msg_);
}
122 
//  Decodes an incoming application message by delegating to
//  curve_mechanism_base_t::decode () (presumably decryption of a CURVE
//  MESSAGE command -- see the base class). This is only legal once the
//  handshake has reached the connected state; anything else is treated as a
//  programming error and trips the assertion.
int zmq::curve_client_t::decode (msg_t *msg_)
{
    zmq_assert (_state == connected);
    return curve_mechanism_base_t::decode (msg_);
}
128 
status() const129 zmq::mechanism_t::status_t zmq::curve_client_t::status () const
130 {
131     if (_state == connected)
132         return mechanism_t::ready;
133     if (_state == error_received)
134         return mechanism_t::error;
135 
136     return mechanism_t::handshaking;
137 }
138 
produce_hello(msg_t * msg_)139 int zmq::curve_client_t::produce_hello (msg_t *msg_)
140 {
141     int rc = msg_->init_size (200);
142     errno_assert (rc == 0);
143 
144     rc = _tools.produce_hello (msg_->data (), get_and_inc_nonce ());
145     if (rc == -1) {
146         session->get_socket ()->event_handshake_failed_protocol (
147           session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
148 
149         // TODO this is somewhat inconsistent: we call init_size, but we may
150         // not close msg_; i.e. we assume that msg_ is initialized but empty
151         // (if it were non-empty, calling init_size might cause a leak!)
152 
153         // msg_->close ();
154         return -1;
155     }
156 
157     return 0;
158 }
159 
process_welcome(const uint8_t * msg_data_,size_t msg_size_)160 int zmq::curve_client_t::process_welcome (const uint8_t *msg_data_,
161                                           size_t msg_size_)
162 {
163     const int rc = _tools.process_welcome (msg_data_, msg_size_,
164                                            get_writable_precom_buffer ());
165 
166     if (rc == -1) {
167         session->get_socket ()->event_handshake_failed_protocol (
168           session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
169 
170         errno = EPROTO;
171         return -1;
172     }
173 
174     _state = send_initiate;
175 
176     return 0;
177 }
178 
produce_initiate(msg_t * msg_)179 int zmq::curve_client_t::produce_initiate (msg_t *msg_)
180 {
181     const size_t metadata_length = basic_properties_len ();
182     std::vector<unsigned char, secure_allocator_t<unsigned char> >
183       metadata_plaintext (metadata_length);
184 
185     add_basic_properties (&metadata_plaintext[0], metadata_length);
186 
187     const size_t msg_size =
188       113 + 128 + crypto_box_BOXZEROBYTES + metadata_length;
189     int rc = msg_->init_size (msg_size);
190     errno_assert (rc == 0);
191 
192     rc = _tools.produce_initiate (msg_->data (), msg_size, get_and_inc_nonce (),
193                                   &metadata_plaintext[0], metadata_length);
194 
195     if (-1 == rc) {
196         session->get_socket ()->event_handshake_failed_protocol (
197           session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
198 
199         // TODO see comment in produce_hello
200         return -1;
201     }
202 
203     return 0;
204 }
205 
process_ready(const uint8_t * msg_data_,size_t msg_size_)206 int zmq::curve_client_t::process_ready (const uint8_t *msg_data_,
207                                         size_t msg_size_)
208 {
209     if (msg_size_ < 30) {
210         session->get_socket ()->event_handshake_failed_protocol (
211           session->get_endpoint (),
212           ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY);
213         errno = EPROTO;
214         return -1;
215     }
216 
217     const size_t clen = (msg_size_ - 14) + crypto_box_BOXZEROBYTES;
218 
219     uint8_t ready_nonce[crypto_box_NONCEBYTES];
220     std::vector<uint8_t, secure_allocator_t<uint8_t> > ready_plaintext (
221       crypto_box_ZEROBYTES + clen);
222     std::vector<uint8_t> ready_box (crypto_box_BOXZEROBYTES + 16 + clen);
223 
224     std::fill (ready_box.begin (), ready_box.begin () + crypto_box_BOXZEROBYTES,
225                0);
226     memcpy (&ready_box[crypto_box_BOXZEROBYTES], msg_data_ + 14,
227             clen - crypto_box_BOXZEROBYTES);
228 
229     memcpy (ready_nonce, "CurveZMQREADY---", 16);
230     memcpy (ready_nonce + 16, msg_data_ + 6, 8);
231     set_peer_nonce (get_uint64 (msg_data_ + 6));
232 
233     int rc = crypto_box_open_afternm (&ready_plaintext[0], &ready_box[0], clen,
234                                       ready_nonce, get_precom_buffer ());
235 
236     if (rc != 0) {
237         session->get_socket ()->event_handshake_failed_protocol (
238           session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
239         errno = EPROTO;
240         return -1;
241     }
242 
243     rc = parse_metadata (&ready_plaintext[crypto_box_ZEROBYTES],
244                          clen - crypto_box_ZEROBYTES);
245 
246     if (rc == 0)
247         _state = connected;
248     else {
249         session->get_socket ()->event_handshake_failed_protocol (
250           session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_METADATA);
251         errno = EPROTO;
252     }
253 
254     return rc;
255 }
256 
process_error(const uint8_t * msg_data_,size_t msg_size_)257 int zmq::curve_client_t::process_error (const uint8_t *msg_data_,
258                                         size_t msg_size_)
259 {
260     if (_state != expect_welcome && _state != expect_ready) {
261         session->get_socket ()->event_handshake_failed_protocol (
262           session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
263         errno = EPROTO;
264         return -1;
265     }
266     if (msg_size_ < 7) {
267         session->get_socket ()->event_handshake_failed_protocol (
268           session->get_endpoint (),
269           ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
270         errno = EPROTO;
271         return -1;
272     }
273     const size_t error_reason_len = static_cast<size_t> (msg_data_[6]);
274     if (error_reason_len > msg_size_ - 7) {
275         session->get_socket ()->event_handshake_failed_protocol (
276           session->get_endpoint (),
277           ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
278         errno = EPROTO;
279         return -1;
280     }
281     const char *error_reason = reinterpret_cast<const char *> (msg_data_) + 7;
282     handle_error_reason (error_reason, error_reason_len);
283     _state = error_received;
284     return 0;
285 }
286 
287 #endif
288