1 /**
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  * SPDX-License-Identifier: Apache-2.0.
4  */
5 
6 #include <aws/mqtt/private/client_impl.h>
7 
8 #include <aws/mqtt/private/packets.h>
9 #include <aws/mqtt/private/topic_tree.h>
10 
11 #include <aws/io/logging.h>
12 
13 #include <aws/common/clock.h>
14 #include <aws/common/math.h>
15 #include <aws/common/task_scheduler.h>
16 
17 #include <inttypes.h>
18 
19 #ifdef _MSC_VER
20 #    pragma warning(disable : 4204)
21 #endif
22 
23 /*******************************************************************************
24  * Packet State Machine
25  ******************************************************************************/
26 
/* Signature shared by every entry of the inbound packet dispatch table. `message_cursor` spans one complete MQTT packet,
 * fixed header included. Handlers return AWS_OP_SUCCESS, or AWS_OP_ERR with an error raised. */
typedef int(packet_handler_fn)(struct aws_mqtt_client_connection *connection, struct aws_byte_cursor message_cursor);
28 
s_packet_handler_default(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)29 static int s_packet_handler_default(
30     struct aws_mqtt_client_connection *connection,
31     struct aws_byte_cursor message_cursor) {
32     (void)connection;
33     (void)message_cursor;
34 
35     AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Unhandled packet type received", (void *)connection);
36     return aws_raise_error(AWS_ERROR_MQTT_INVALID_PACKET_TYPE);
37 }
38 
39 static void s_on_time_to_ping(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status);
s_schedule_ping(struct aws_mqtt_client_connection * connection)40 static void s_schedule_ping(struct aws_mqtt_client_connection *connection) {
41     aws_channel_task_init(&connection->ping_task, s_on_time_to_ping, connection, "mqtt_ping");
42 
43     uint64_t now = 0;
44     aws_channel_current_clock_time(connection->slot->channel, &now);
45     AWS_LOGF_TRACE(
46         AWS_LS_MQTT_CLIENT, "id=%p: Scheduling PING. current timestamp is %" PRIu64, (void *)connection, now);
47 
48     uint64_t schedule_time =
49         now + aws_timestamp_convert(connection->keep_alive_time_secs, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL);
50 
51     AWS_LOGF_TRACE(
52         AWS_LS_MQTT_CLIENT,
53         "id=%p: The next ping will be run at timestamp %" PRIu64,
54         (void *)connection,
55         schedule_time);
56     aws_channel_schedule_task_future(connection->slot->channel, &connection->ping_task, schedule_time);
57 }
58 
s_on_time_to_ping(struct aws_channel_task * channel_task,void * arg,enum aws_task_status status)59 static void s_on_time_to_ping(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
60     (void)channel_task;
61 
62     if (status == AWS_TASK_STATUS_RUN_READY) {
63         struct aws_mqtt_client_connection *connection = arg;
64         AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Sending PING", (void *)connection);
65         aws_mqtt_client_connection_ping(connection);
66         s_schedule_ping(connection);
67     }
68 }
/*
 * Handles an inbound CONNACK.
 *
 * - Decode failure: logs and returns AWS_OP_ERR.
 * - If the state is already DISCONNECTING or later (user asked to disconnect), the CONNACK is ignored.
 * - Accepted: the state becomes CONNECTED, the offline request queue is taken over under the lock, and every queued
 *   request's outgoing task is scheduled on the channel. Then on_resumed (a reconnect after an earlier connection)
 *   or on_connection_complete is invoked, and the keep-alive ping is armed.
 * - Rejected: logs and shuts the channel down with AWS_ERROR_MQTT_PROTOCOL_ERROR. The completion callback is
 *   expected to be delivered from the shutdown path (per the comment below).
 */
static int s_packet_handler_connack(
    struct aws_mqtt_client_connection *connection,
    struct aws_byte_cursor message_cursor) {

    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: CONNACK received", (void *)connection);

    struct aws_mqtt_packet_connack connack;
    if (aws_mqtt_packet_connack_decode(&message_cursor, &connack)) {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT, "id=%p: error %d parsing CONNACK packet", (void *)connection, aws_last_error());

        return AWS_OP_ERR;
    }
    bool was_reconnecting;
    /* Receives the offline (pending) request queue, swapped out under the lock when the connection is accepted. */
    struct aws_linked_list requests;
    aws_linked_list_init(&requests);
    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        /* User requested disconnect, don't do anything */
        if (connection->synced_data.state >= AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
            mqtt_connection_unlock_synced_data(connection);
            AWS_LOGF_TRACE(
                AWS_LS_MQTT_CLIENT, "id=%p: User has requested disconnect, dropping connection", (void *)connection);
            return AWS_OP_SUCCESS;
        }

        /* Captured before the state is overwritten below; decides on_resumed vs on_connection_complete. */
        was_reconnecting = connection->synced_data.state == AWS_MQTT_CLIENT_STATE_RECONNECTING;
        if (connack.connect_return_code == AWS_MQTT_CONNECT_ACCEPTED) {
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "id=%p: connection was accepted, switch state from %d to CONNECTED.",
                (void *)connection,
                (int)connection->synced_data.state);
            /* Don't change the state if it's not ACCEPTED by broker */
            mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_CONNECTED);
            aws_linked_list_swap_contents(&connection->synced_data.pending_requests_list, &requests);
        }
        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */
    /* NOTE(review): the counter and the backoff reset below also run for a *rejected* CONNACK. A rejection therefore
     * resets reconnect backoff to its minimum and bumps connection_count. Confirm this is intended. */
    connection->connection_count++;

    /* Reset the current timeout timer */
    connection->reconnect_timeouts.current = connection->reconnect_timeouts.min;

    if (connack.connect_return_code == AWS_MQTT_CONNECT_ACCEPTED) {
        /* If successfully connected, schedule all pending tasks */
        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT, "id=%p: connection was accepted processing offline requests.", (void *)connection);

        if (!aws_linked_list_empty(&requests)) {

            /* Walk the local list without unlinking nodes. Each request's list_node is re-linked elsewhere before
             * reuse (e.g. pushed onto the pending list if its outgoing task is cancelled). */
            struct aws_linked_list_node *current = aws_linked_list_front(&requests);
            const struct aws_linked_list_node *end = aws_linked_list_end(&requests);

            do {
                struct aws_mqtt_request *request = AWS_CONTAINER_OF(current, struct aws_mqtt_request, list_node);
                AWS_LOGF_TRACE(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: processing offline request %" PRIu16,
                    (void *)connection,
                    request->packet_id);
                aws_channel_schedule_task_now(connection->slot->channel, &request->outgoing_task);
                current = current->next;
            } while (current != end);
        }
    } else {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: invalid connect return code %d, disconnecting",
            (void *)connection,
            connack.connect_return_code);
        /* If error code returned, disconnect, on_completed will be invoked from shutdown process */
        aws_channel_shutdown(connection->slot->channel, AWS_ERROR_MQTT_PROTOCOL_ERROR);

        return AWS_OP_SUCCESS;
    }

    /* It is possible for a connection to complete, and a hangup to occur before the
     * CONNECT/CONNACK cycle completes. In that case, we must deliver on_connection_complete
     * on the first successful CONNACK or user code will never think it's connected */
    if (was_reconnecting && connection->connection_count > 1) {

        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT,
            "id=%p: connection is a resumed connection, invoking on_resumed callback",
            (void *)connection);

        MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_resumed, connack.connect_return_code, connack.session_present);
    } else {

        aws_create_reconnect_task(connection);

        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT,
            "id=%p: connection is a new connection, invoking on_connection_complete callback",
            (void *)connection);
        MQTT_CLIENT_CALL_CALLBACK_ARGS(
            connection, on_connection_complete, AWS_OP_SUCCESS, connack.connect_return_code, connack.session_present);
    }

    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: connection callback completed", (void *)connection);

    /* Only an accepted connection reaches this point; start keep-alive pings. */
    s_schedule_ping(connection);
    return AWS_OP_SUCCESS;
}
174 
s_packet_handler_publish(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)175 static int s_packet_handler_publish(
176     struct aws_mqtt_client_connection *connection,
177     struct aws_byte_cursor message_cursor) {
178 
179     /* TODO: need to handle the QoS 2 message to avoid processing the message a second time */
180     struct aws_mqtt_packet_publish publish;
181     if (aws_mqtt_packet_publish_decode(&message_cursor, &publish)) {
182         return AWS_OP_ERR;
183     }
184 
185     aws_mqtt_topic_tree_publish(&connection->thread_data.subscriptions, &publish);
186 
187     bool dup = aws_mqtt_packet_publish_get_dup(&publish);
188     enum aws_mqtt_qos qos = aws_mqtt_packet_publish_get_qos(&publish);
189     bool retain = aws_mqtt_packet_publish_get_retain(&publish);
190 
191     MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_any_publish, &publish.topic_name, &publish.payload, dup, qos, retain);
192 
193     AWS_LOGF_TRACE(
194         AWS_LS_MQTT_CLIENT,
195         "id=%p: publish received with msg id=%" PRIu16 " dup=%d qos=%d retain=%d payload-size=%zu topic=" PRInSTR,
196         (void *)connection,
197         publish.packet_identifier,
198         dup,
199         qos,
200         retain,
201         publish.payload.len,
202         AWS_BYTE_CURSOR_PRI(publish.topic_name));
203     struct aws_mqtt_packet_ack puback;
204     AWS_ZERO_STRUCT(puback);
205 
206     /* Switch on QoS flags (bits 1 & 2) */
207     switch (qos) {
208         case AWS_MQTT_QOS_AT_MOST_ONCE:
209             AWS_LOGF_TRACE(
210                 AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 0, not sending puback", (void *)connection);
211             /* No more communication necessary */
212             break;
213         case AWS_MQTT_QOS_AT_LEAST_ONCE:
214             AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 1, sending puback", (void *)connection);
215             aws_mqtt_packet_puback_init(&puback, publish.packet_identifier);
216             break;
217         case AWS_MQTT_QOS_EXACTLY_ONCE:
218             AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 2, sending pubrec", (void *)connection);
219             aws_mqtt_packet_pubrec_init(&puback, publish.packet_identifier);
220             break;
221         default:
222             /* Impossible to hit this branch. QoS value is checked when decoding */
223             AWS_FATAL_ASSERT(0);
224             break;
225     }
226 
227     if (puback.packet_identifier) {
228         struct aws_io_message *message = mqtt_get_message_for_packet(connection, &puback.fixed_header);
229         if (!message) {
230             return AWS_OP_ERR;
231         }
232 
233         if (aws_mqtt_packet_ack_encode(&message->message_data, &puback)) {
234             aws_mem_release(message->allocator, message);
235             return AWS_OP_ERR;
236         }
237 
238         if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
239             aws_mem_release(message->allocator, message);
240             return AWS_OP_ERR;
241         }
242     }
243 
244     return AWS_OP_SUCCESS;
245 }
246 
s_packet_handler_ack(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)247 static int s_packet_handler_ack(struct aws_mqtt_client_connection *connection, struct aws_byte_cursor message_cursor) {
248     struct aws_mqtt_packet_ack ack;
249     if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
250         return AWS_OP_ERR;
251     }
252 
253     AWS_LOGF_TRACE(
254         AWS_LS_MQTT_CLIENT, "id=%p: received ack for message id %" PRIu16, (void *)connection, ack.packet_identifier);
255 
256     mqtt_request_complete(connection, AWS_ERROR_SUCCESS, ack.packet_identifier);
257 
258     return AWS_OP_SUCCESS;
259 }
260 
s_packet_handler_suback(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)261 static int s_packet_handler_suback(
262     struct aws_mqtt_client_connection *connection,
263     struct aws_byte_cursor message_cursor) {
264     struct aws_mqtt_packet_suback suback;
265     if (aws_mqtt_packet_suback_init(&suback, connection->allocator, 0 /* fake packet_id */)) {
266         return AWS_OP_ERR;
267     }
268 
269     if (aws_mqtt_packet_suback_decode(&message_cursor, &suback)) {
270         goto error;
271     }
272 
273     AWS_LOGF_TRACE(
274         AWS_LS_MQTT_CLIENT,
275         "id=%p: received suback for message id %" PRIu16,
276         (void *)connection,
277         suback.packet_identifier);
278 
279     struct aws_mqtt_request *request = NULL;
280 
281     { /* BEGIN CRITICAL SECTION */
282         mqtt_connection_lock_synced_data(connection);
283         struct aws_hash_element *elem = NULL;
284         aws_hash_table_find(&connection->synced_data.outstanding_requests_table, &suback.packet_identifier, &elem);
285         if (elem != NULL) {
286             request = elem->value;
287         }
288         mqtt_connection_unlock_synced_data(connection);
289     } /* END CRITICAL SECTION */
290 
291     if (request == NULL) {
292         /* no corresponding request found */
293         goto done;
294     }
295 
296     struct subscribe_task_arg *task_arg = request->on_complete_ud;
297     size_t request_topics_len = aws_array_list_length(&task_arg->topics);
298     size_t suback_return_code_len = aws_array_list_length(&suback.return_codes);
299     if (request_topics_len != suback_return_code_len) {
300         goto error;
301     }
302     size_t num_filters = aws_array_list_length(&suback.return_codes);
303     for (size_t i = 0; i < num_filters; ++i) {
304 
305         uint8_t return_code = 0;
306         struct subscribe_task_topic *topic = NULL;
307         aws_array_list_get_at(&suback.return_codes, (void *)&return_code, i);
308         aws_array_list_get_at(&task_arg->topics, &topic, i);
309         topic->request.qos = return_code;
310     }
311 
312 done:
313     mqtt_request_complete(connection, AWS_ERROR_SUCCESS, suback.packet_identifier);
314     aws_mqtt_packet_suback_clean_up(&suback);
315     return AWS_OP_SUCCESS;
316 error:
317     aws_mqtt_packet_suback_clean_up(&suback);
318     return AWS_OP_ERR;
319 }
320 
s_packet_handler_pubrec(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)321 static int s_packet_handler_pubrec(
322     struct aws_mqtt_client_connection *connection,
323     struct aws_byte_cursor message_cursor) {
324 
325     struct aws_mqtt_packet_ack ack;
326     if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
327         return AWS_OP_ERR;
328     }
329 
330     /* TODO: When sending PUBLISH with QoS 2, we should be storing the data until this packet is received, at which
331      * point we may discard it. */
332 
333     /* Send PUBREL */
334     aws_mqtt_packet_pubrel_init(&ack, ack.packet_identifier);
335     struct aws_io_message *message = mqtt_get_message_for_packet(connection, &ack.fixed_header);
336     if (!message) {
337         return AWS_OP_ERR;
338     }
339 
340     if (aws_mqtt_packet_ack_encode(&message->message_data, &ack)) {
341         goto on_error;
342     }
343 
344     if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
345         goto on_error;
346     }
347 
348     return AWS_OP_SUCCESS;
349 
350 on_error:
351 
352     if (message) {
353         aws_mem_release(message->allocator, message);
354     }
355 
356     return AWS_OP_ERR;
357 }
358 
s_packet_handler_pubrel(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)359 static int s_packet_handler_pubrel(
360     struct aws_mqtt_client_connection *connection,
361     struct aws_byte_cursor message_cursor) {
362 
363     struct aws_mqtt_packet_ack ack;
364     if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
365         return AWS_OP_ERR;
366     }
367 
368     /* Send PUBCOMP */
369     aws_mqtt_packet_pubcomp_init(&ack, ack.packet_identifier);
370     struct aws_io_message *message = mqtt_get_message_for_packet(connection, &ack.fixed_header);
371     if (!message) {
372         return AWS_OP_ERR;
373     }
374 
375     if (aws_mqtt_packet_ack_encode(&message->message_data, &ack)) {
376         goto on_error;
377     }
378 
379     if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
380         goto on_error;
381     }
382 
383     return AWS_OP_SUCCESS;
384 
385 on_error:
386 
387     if (message) {
388         aws_mem_release(message->allocator, message);
389     }
390 
391     return AWS_OP_ERR;
392 }
393 
s_packet_handler_pingresp(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)394 static int s_packet_handler_pingresp(
395     struct aws_mqtt_client_connection *connection,
396     struct aws_byte_cursor message_cursor) {
397 
398     (void)message_cursor;
399 
400     AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: PINGRESP received", (void *)connection);
401 
402     connection->thread_data.waiting_on_ping_response = false;
403 
404     return AWS_OP_SUCCESS;
405 }
406 
407 /* Bake up a big ol' function table just like Gramma used to make */
408 static packet_handler_fn *s_packet_handlers[] = {
409     [AWS_MQTT_PACKET_CONNECT] = &s_packet_handler_default,
410     [AWS_MQTT_PACKET_CONNACK] = &s_packet_handler_connack,
411     [AWS_MQTT_PACKET_PUBLISH] = &s_packet_handler_publish,
412     [AWS_MQTT_PACKET_PUBACK] = &s_packet_handler_ack,
413     [AWS_MQTT_PACKET_PUBREC] = &s_packet_handler_pubrec,
414     [AWS_MQTT_PACKET_PUBREL] = &s_packet_handler_pubrel,
415     [AWS_MQTT_PACKET_PUBCOMP] = &s_packet_handler_ack,
416     [AWS_MQTT_PACKET_SUBSCRIBE] = &s_packet_handler_default,
417     [AWS_MQTT_PACKET_SUBACK] = &s_packet_handler_suback,
418     [AWS_MQTT_PACKET_UNSUBSCRIBE] = &s_packet_handler_default,
419     [AWS_MQTT_PACKET_UNSUBACK] = &s_packet_handler_ack,
420     [AWS_MQTT_PACKET_PINGREQ] = &s_packet_handler_default,
421     [AWS_MQTT_PACKET_PINGRESP] = &s_packet_handler_pingresp,
422     [AWS_MQTT_PACKET_DISCONNECT] = &s_packet_handler_default,
423 };
424 
425 /*******************************************************************************
426  * Channel Handler
427  ******************************************************************************/
428 
s_process_mqtt_packet(struct aws_mqtt_client_connection * connection,enum aws_mqtt_packet_type packet_type,struct aws_byte_cursor packet)429 static int s_process_mqtt_packet(
430     struct aws_mqtt_client_connection *connection,
431     enum aws_mqtt_packet_type packet_type,
432     struct aws_byte_cursor packet) {
433     { /* BEGIN CRITICAL SECTION */
434         mqtt_connection_lock_synced_data(connection);
435         /* [MQTT-3.2.0-1] The first packet sent from the Server to the Client MUST be a CONNACK Packet */
436         if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_CONNECTING &&
437             packet_type != AWS_MQTT_PACKET_CONNACK) {
438             mqtt_connection_unlock_synced_data(connection);
439             AWS_LOGF_ERROR(
440                 AWS_LS_MQTT_CLIENT,
441                 "id=%p: First message received from the server was not a CONNACK. Terminating connection.",
442                 (void *)connection);
443             aws_channel_shutdown(connection->slot->channel, AWS_ERROR_MQTT_PROTOCOL_ERROR);
444             return aws_raise_error(AWS_ERROR_MQTT_PROTOCOL_ERROR);
445         }
446         mqtt_connection_unlock_synced_data(connection);
447     } /* END CRITICAL SECTION */
448 
449     if (AWS_UNLIKELY(packet_type > AWS_MQTT_PACKET_DISCONNECT || packet_type < AWS_MQTT_PACKET_CONNECT)) {
450         AWS_LOGF_ERROR(
451             AWS_LS_MQTT_CLIENT,
452             "id=%p: Invalid packet type received %d. Terminating connection.",
453             (void *)connection,
454             packet_type);
455         return aws_raise_error(AWS_ERROR_MQTT_INVALID_PACKET_TYPE);
456     }
457 
458     /* Handle the packet */
459     return s_packet_handlers[packet_type](connection, packet);
460 }
461 
462 /**
463  * Handles incoming messages from the server.
464  */
s_process_read_message(struct aws_channel_handler * handler,struct aws_channel_slot * slot,struct aws_io_message * message)465 static int s_process_read_message(
466     struct aws_channel_handler *handler,
467     struct aws_channel_slot *slot,
468     struct aws_io_message *message) {
469 
470     struct aws_mqtt_client_connection *connection = handler->impl;
471 
472     if (message->message_type != AWS_IO_MESSAGE_APPLICATION_DATA || message->message_data.len < 1) {
473         return AWS_OP_ERR;
474     }
475 
476     AWS_LOGF_TRACE(
477         AWS_LS_MQTT_CLIENT,
478         "id=%p: precessing read message of size %zu",
479         (void *)connection,
480         message->message_data.len);
481 
482     /* This cursor will be updated as we read through the message. */
483     struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data);
484 
485     /* If there's pending packet left over from last time, attempt to complete it. */
486     if (connection->thread_data.pending_packet.len) {
487         int result = AWS_OP_SUCCESS;
488 
489         /* This determines how much to read from the message (min(expected_remaining, message.len)) */
490         size_t to_read = connection->thread_data.pending_packet.capacity - connection->thread_data.pending_packet.len;
491         /* This will be set to false if this message still won't complete the packet object. */
492         bool packet_complete = true;
493         if (to_read > message_cursor.len) {
494             to_read = message_cursor.len;
495             packet_complete = false;
496         }
497 
498         /* Write the chunk to the buffer.
499          * This will either complete the packet, or be the entirety of message if more data is required. */
500         struct aws_byte_cursor chunk = aws_byte_cursor_advance(&message_cursor, to_read);
501         AWS_ASSERT(chunk.ptr); /* Guaranteed to be in bounds */
502         result = (int)aws_byte_buf_write_from_whole_cursor(&connection->thread_data.pending_packet, chunk) - 1;
503         if (result) {
504             goto handle_error;
505         }
506 
507         /* If the packet is still incomplete, don't do anything with the data. */
508         if (!packet_complete) {
509             AWS_LOGF_TRACE(
510                 AWS_LS_MQTT_CLIENT,
511                 "id=%p: partial message is still incomplete, waiting on another read.",
512                 (void *)connection);
513 
514             goto cleanup;
515         }
516 
517         /* Handle the completed pending packet */
518         struct aws_byte_cursor packet_data = aws_byte_cursor_from_buf(&connection->thread_data.pending_packet);
519         AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: full mqtt packet re-assembled, dispatching.", (void *)connection);
520         result = s_process_mqtt_packet(connection, aws_mqtt_get_packet_type(packet_data.ptr), packet_data);
521 
522     handle_error:
523         /* Clean up the pending packet */
524         aws_byte_buf_clean_up(&connection->thread_data.pending_packet);
525         AWS_ZERO_STRUCT(connection->thread_data.pending_packet);
526 
527         if (result) {
528             return AWS_OP_ERR;
529         }
530     }
531 
532     while (message_cursor.len) {
533 
534         /* Temp byte cursor so we can decode the header without advancing message_cursor. */
535         struct aws_byte_cursor header_decode = message_cursor;
536 
537         struct aws_mqtt_fixed_header packet_header;
538         AWS_ZERO_STRUCT(packet_header);
539         int result = aws_mqtt_fixed_header_decode(&header_decode, &packet_header);
540 
541         /* Calculate how much data was read. */
542         const size_t fixed_header_size = message_cursor.len - header_decode.len;
543 
544         if (result) {
545             if (aws_last_error() == AWS_ERROR_SHORT_BUFFER) {
546                 /* Message data too short, store data and come back later. */
547                 AWS_LOGF_TRACE(
548                     AWS_LS_MQTT_CLIENT, "id=%p: message is incomplete, waiting on another read.", (void *)connection);
549                 if (aws_byte_buf_init(
550                         &connection->thread_data.pending_packet,
551                         connection->allocator,
552                         fixed_header_size + packet_header.remaining_length)) {
553 
554                     return AWS_OP_ERR;
555                 }
556 
557                 /* Write the partial packet. */
558                 if (!aws_byte_buf_write_from_whole_cursor(&connection->thread_data.pending_packet, message_cursor)) {
559                     aws_byte_buf_clean_up(&connection->thread_data.pending_packet);
560                     return AWS_OP_ERR;
561                 }
562 
563                 aws_reset_error();
564                 goto cleanup;
565             } else {
566                 return AWS_OP_ERR;
567             }
568         }
569 
570         struct aws_byte_cursor packet_data =
571             aws_byte_cursor_advance(&message_cursor, fixed_header_size + packet_header.remaining_length);
572         AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: full mqtt packet read, dispatching.", (void *)connection);
573         s_process_mqtt_packet(connection, packet_header.packet_type, packet_data);
574     }
575 
576 cleanup:
577     /* Do cleanup */
578     aws_channel_slot_increment_read_window(slot, message->message_data.len);
579     aws_mem_release(message->allocator, message);
580 
581     return AWS_OP_SUCCESS;
582 }
583 
s_shutdown(struct aws_channel_handler * handler,struct aws_channel_slot * slot,enum aws_channel_direction dir,int error_code,bool free_scarce_resources_immediately)584 static int s_shutdown(
585     struct aws_channel_handler *handler,
586     struct aws_channel_slot *slot,
587     enum aws_channel_direction dir,
588     int error_code,
589     bool free_scarce_resources_immediately) {
590 
591     struct aws_mqtt_client_connection *connection = handler->impl;
592 
593     if (dir == AWS_CHANNEL_DIR_WRITE) {
594         /* On closing write direction, send out disconnect packet before closing connection. */
595 
596         if (!free_scarce_resources_immediately) {
597 
598             if (error_code == AWS_OP_SUCCESS) {
599                 AWS_LOGF_INFO(
600                     AWS_LS_MQTT_CLIENT,
601                     "id=%p: sending disconnect message as part of graceful shutdown.",
602                     (void *)connection);
603                 /* On clean shutdown, send the disconnect message */
604                 struct aws_mqtt_packet_connection disconnect;
605                 aws_mqtt_packet_disconnect_init(&disconnect);
606 
607                 struct aws_io_message *message = mqtt_get_message_for_packet(connection, &disconnect.fixed_header);
608                 if (!message) {
609                     goto done;
610                 }
611 
612                 if (aws_mqtt_packet_connection_encode(&message->message_data, &disconnect)) {
613                     AWS_LOGF_DEBUG(
614                         AWS_LS_MQTT_CLIENT,
615                         "id=%p: failed to encode courteous disconnect io message",
616                         (void *)connection);
617                     aws_mem_release(message->allocator, message);
618                     goto done;
619                 }
620 
621                 if (aws_channel_slot_send_message(slot, message, AWS_CHANNEL_DIR_WRITE)) {
622                     AWS_LOGF_DEBUG(
623                         AWS_LS_MQTT_CLIENT,
624                         "id=%p: failed to send courteous disconnect io message",
625                         (void *)connection);
626                     aws_mem_release(message->allocator, message);
627                     goto done;
628                 }
629             }
630         }
631     }
632 
633 done:
634     return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, free_scarce_resources_immediately);
635 }
636 
/* Channel-handler hook: advertise the largest possible initial read window (SIZE_MAX). */
static size_t s_initial_window_size(struct aws_channel_handler *handler) {
    (void)handler;

    const size_t unbounded_window = SIZE_MAX;
    return unbounded_window;
}
643 
s_destroy(struct aws_channel_handler * handler)644 static void s_destroy(struct aws_channel_handler *handler) {
645 
646     struct aws_mqtt_client_connection *connection = handler->impl;
647     (void)connection;
648 }
649 
/* Channel-handler hook: this handler adds no per-message overhead. */
static size_t s_message_overhead(struct aws_channel_handler *handler) {
    (void)handler;

    const size_t no_overhead = 0;
    return no_overhead;
}
654 
aws_mqtt_get_client_channel_vtable(void)655 struct aws_channel_handler_vtable *aws_mqtt_get_client_channel_vtable(void) {
656 
657     static struct aws_channel_handler_vtable s_vtable = {
658         .process_read_message = &s_process_read_message,
659         .process_write_message = NULL,
660         .increment_read_window = NULL,
661         .shutdown = &s_shutdown,
662         .initial_window_size = &s_initial_window_size,
663         .message_overhead = &s_message_overhead,
664         .destroy = &s_destroy,
665     };
666 
667     return &s_vtable;
668 }
669 
670 /*******************************************************************************
671  * Helpers
672  ******************************************************************************/
673 
mqtt_get_message_for_packet(struct aws_mqtt_client_connection * connection,struct aws_mqtt_fixed_header * header)674 struct aws_io_message *mqtt_get_message_for_packet(
675     struct aws_mqtt_client_connection *connection,
676     struct aws_mqtt_fixed_header *header) {
677 
678     const size_t required_length = 3 + header->remaining_length;
679 
680     struct aws_io_message *message = aws_channel_acquire_message_from_pool(
681         connection->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, required_length);
682 
683     AWS_LOGF_TRACE(
684         AWS_LS_MQTT_CLIENT,
685         "id=%p: Acquiring memory from pool of required_length %zu",
686         (void *)connection,
687         required_length);
688 
689     return message;
690 }
691 
692 /*******************************************************************************
693  * Requests
694  ******************************************************************************/
695 
696 /* Send the request */
s_request_outgoing_task(struct aws_channel_task * task,void * arg,enum aws_task_status status)697 static void s_request_outgoing_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) {
698 
699     struct aws_mqtt_request *request = arg;
700     struct aws_mqtt_client_connection *connection = request->connection;
701 
702     if (status == AWS_TASK_STATUS_CANCELED) {
703         /* Connection lost before the request ever get send, check the request needs to be retried or not */
704         if (request->retryable) {
705             AWS_LOGF_DEBUG(
706                 AWS_LS_MQTT_CLIENT,
707                 "static: task id %p, was canceled due to the channel shutting down. Request for packet id "
708                 "%" PRIu16 ". will be retried",
709                 (void *)task,
710                 request->packet_id);
711             /* put it into the offline queue. */
712             { /* BEGIN CRITICAL SECTION */
713                 mqtt_connection_lock_synced_data(connection);
714                 aws_linked_list_push_back(&connection->synced_data.pending_requests_list, &request->list_node);
715                 mqtt_connection_unlock_synced_data(connection);
716             } /* END CRITICAL SECTION */
717         } else {
718             AWS_LOGF_DEBUG(
719                 AWS_LS_MQTT_CLIENT,
720                 "static: task id %p, was canceled due to the channel shutting down. Request for packet id "
721                 "%" PRIu16 ". will NOT be retried, will be cancelled",
722                 (void *)task,
723                 request->packet_id);
724             /* Fire the callback and clean up the memory, as the connection get destroyed. */
725             if (request->on_complete) {
726                 request->on_complete(
727                     connection, request->packet_id, AWS_ERROR_MQTT_NOT_CONNECTED, request->on_complete_ud);
728             }
729             { /* BEGIN CRITICAL SECTION */
730                 mqtt_connection_lock_synced_data(connection);
731                 aws_hash_table_remove(
732                     &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
733                 aws_memory_pool_release(&connection->synced_data.requests_pool, request);
734                 mqtt_connection_unlock_synced_data(connection);
735             } /* END CRITICAL SECTION */
736         }
737         return;
738     }
739 
740     /* Send the request */
741     enum aws_mqtt_client_request_state state =
742         request->send_request(request->packet_id, !request->initiated, request->send_request_ud);
743     request->initiated = true;
744     int error_code = AWS_ERROR_SUCCESS;
745     switch (state) {
746         case AWS_MQTT_CLIENT_REQUEST_ERROR:
747             error_code = aws_last_error();
748             AWS_LOGF_ERROR(
749                 AWS_LS_MQTT_CLIENT,
750                 "id=%p: sending request %" PRIu16 " failed with error %d.",
751                 (void *)request->connection,
752                 request->packet_id,
753                 error_code);
754             /* fall-thru */
755 
756         case AWS_MQTT_CLIENT_REQUEST_COMPLETE:
757             AWS_LOGF_TRACE(
758                 AWS_LS_MQTT_CLIENT,
759                 "id=%p: sending request %" PRIu16 " complete, invoking on_complete callback.",
760                 (void *)request->connection,
761                 request->packet_id);
762             /* If the send_request function reports the request is complete,
763              * remove from the hash table and call the callback. */
764             if (request->on_complete) {
765                 request->on_complete(connection, request->packet_id, error_code, request->on_complete_ud);
766             }
767             { /* BEGIN CRITICAL SECTION */
768                 mqtt_connection_lock_synced_data(connection);
769                 aws_hash_table_remove(
770                     &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
771                 aws_memory_pool_release(&connection->synced_data.requests_pool, request);
772                 mqtt_connection_unlock_synced_data(connection);
773             } /* END CRITICAL SECTION */
774             break;
775 
776         case AWS_MQTT_CLIENT_REQUEST_ONGOING:
777             AWS_LOGF_TRACE(
778                 AWS_LS_MQTT_CLIENT,
779                 "id=%p: request %" PRIu16 " sent, but waiting on an acknowledgement from peer.",
780                 (void *)request->connection,
781                 request->packet_id);
782             /* Put the request into the ongoing list */
783             aws_linked_list_push_back(&connection->thread_data.ongoing_requests_list, &request->list_node);
784             break;
785     }
786 }
787 
/*
 * Creates a request, assigns it an unused packet id and either schedules it for sending on the channel (when
 * connected) or appends it to the offline pending_requests_list.
 *
 * connection      - connection that will own the request.
 * send_request    - called from the outgoing channel task to write the request; must not be NULL.
 * send_request_ud - user data passed to send_request.
 * on_complete     - optional callback fired when the request finishes or fails.
 * on_complete_ud  - user data passed to on_complete.
 * noRetry         - when true, the request is refused while not connected (no offline queueing; used for QoS 0
 *                   publish and PINGREQ) and is not retried if the channel goes away before it is sent.
 *
 * Returns the packet id assigned to the request, or 0 on failure. Raised errors on the visible paths are
 * AWS_ERROR_MQTT_CONNECTION_DISCONNECTING, AWS_ERROR_MQTT_NOT_CONNECTED and AWS_ERROR_MQTT_QUEUE_FULL; request-pool
 * and hash-table failures also return 0.
 */
uint16_t mqtt_create_request(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_send_request_fn *send_request,
    void *send_request_ud,
    aws_mqtt_op_complete_fn *on_complete,
    void *on_complete_ud,
    bool noRetry) {

    AWS_ASSERT(connection);
    AWS_ASSERT(send_request);
    struct aws_mqtt_request *next_request = NULL;
    /* Once the lock is released (or the outgoing task is scheduled) another thread may complete the request and
     * return it to the pool, so the id handed back to the caller is copied out while the lock is still held. */
    uint16_t next_request_packet_id = 0;
    bool should_schedule_task = false;
    struct aws_channel *channel = NULL;
    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
            mqtt_connection_unlock_synced_data(connection);
            /* User requested disconnecting, ensure no new requests are made until the channel finished shutting
             * down. */
            AWS_LOGF_ERROR(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Disconnect requested, stop creating any new request until disconnect process finishes.",
                (void *)connection);
            aws_raise_error(AWS_ERROR_MQTT_CONNECTION_DISCONNECTING);
            return 0;
        }

        if (noRetry && connection->synced_data.state != AWS_MQTT_CLIENT_STATE_CONNECTED) {
            mqtt_connection_unlock_synced_data(connection);
            /* Not offline queueing QoS 0 publish or PINGREQ. Fail the call. */
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Not currently connected. No offline queueing for QoS 0 publish or pingreq.",
                (void *)connection);
            aws_raise_error(AWS_ERROR_MQTT_NOT_CONNECTED);
            return 0;
        }

        /*
         * Find a free packet ID.
         * QoS 0 PUBLISH packets don't actually need an ID on the wire,
         * but we assign them internally anyway just so everything has a unique ID.
         *
         * This is an O(N) search. We remember the last ID we assigned, so it's O(1) in the common case,
         * but it can approach N = 65535 if the user lets a lot of un-ack'd messages queue up.
         *
         * The search is bounded to one full lap of the 1..UINT16_MAX id space. Stopping on "back at the starting
         * id" would never terminate when the counter starts at 0 (a value this loop never produces) and every
         * id is taken.
         */
        bool found_free_id = false;
        for (uint32_t attempt = 0; attempt < (uint32_t)UINT16_MAX; ++attempt) {
            /* Increment ID, watch out for overflow, ID cannot be 0 */
            if (connection->synced_data.packet_id == UINT16_MAX) {
                connection->synced_data.packet_id = 1;
            } else {
                connection->synced_data.packet_id++;
            }

            /* Is there already an outstanding request using this ID? */
            struct aws_hash_element *elem = NULL;
            aws_hash_table_find(
                &connection->synced_data.outstanding_requests_table, &connection->synced_data.packet_id, &elem);
            if (elem == NULL) {
                found_free_id = true;
                break;
            }
        }

        if (!found_free_id) {
            /* Every ID is taken */
            mqtt_connection_unlock_synced_data(connection);
            AWS_LOGF_ERROR(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Queue is full. No more packet IDs are available at this time.",
                (void *)connection);
            aws_raise_error(AWS_ERROR_MQTT_QUEUE_FULL);
            return 0;
        }

        next_request = aws_memory_pool_acquire(&connection->synced_data.requests_pool);
        if (!next_request) {
            mqtt_connection_unlock_synced_data(connection);
            return 0;
        }
        memset(next_request, 0, sizeof(*next_request));

        next_request->packet_id = connection->synced_data.packet_id;
        next_request_packet_id = next_request->packet_id;

        /* Store the request by packet_id */
        if (aws_hash_table_put(
                &connection->synced_data.outstanding_requests_table, &next_request->packet_id, next_request, NULL)) {
            /* failed to put the next request into the table */
            aws_memory_pool_release(&connection->synced_data.requests_pool, next_request);
            mqtt_connection_unlock_synced_data(connection);
            return 0;
        }
        next_request->allocator = connection->allocator;
        next_request->connection = connection;
        next_request->initiated = false;
        next_request->retryable = !noRetry;
        next_request->send_request = send_request;
        next_request->send_request_ud = send_request_ud;
        next_request->on_complete = on_complete;
        next_request->on_complete_ud = on_complete_ud;
        aws_channel_task_init(
            &next_request->outgoing_task, s_request_outgoing_task, next_request, "mqtt_outgoing_request_task");
        if (connection->synced_data.state != AWS_MQTT_CLIENT_STATE_CONNECTED) {
            aws_linked_list_push_back(&connection->synced_data.pending_requests_list, &next_request->list_node);
        } else {
            AWS_ASSERT(connection->slot);
            AWS_ASSERT(connection->slot->channel);
            should_schedule_task = true;
            channel = connection->slot->channel;
            /* keep the channel alive until the task is scheduled */
            aws_channel_acquire_hold(channel);
        }
        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */

    if (should_schedule_task) {
        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Currently not in the event-loop thread, scheduling a task to send message id %" PRIu16 ".",
            (void *)connection,
            next_request_packet_id);
        /* Once scheduled, the task may run and release the request at any time; next_request is not used again. */
        aws_channel_schedule_task_now(channel, &next_request->outgoing_task);
        /* release the refcount we hold with the protection of lock */
        aws_channel_release_hold(channel);
    }

    return next_request_packet_id;
}
916 
mqtt_request_complete(struct aws_mqtt_client_connection * connection,int error_code,uint16_t packet_id)917 void mqtt_request_complete(struct aws_mqtt_client_connection *connection, int error_code, uint16_t packet_id) {
918 
919     AWS_LOGF_TRACE(
920         AWS_LS_MQTT_CLIENT,
921         "id=%p: message id %" PRIu16 " completed with error code %d, removing from outstanding requests list.",
922         (void *)connection,
923         packet_id,
924         error_code);
925 
926     bool found_request = false;
927     aws_mqtt_op_complete_fn *on_complete = NULL;
928     void *on_complete_ud = NULL;
929 
930     { /* BEGIN CRITICAL SECTION */
931         mqtt_connection_lock_synced_data(connection);
932         struct aws_hash_element *elem = NULL;
933         aws_hash_table_find(&connection->synced_data.outstanding_requests_table, &packet_id, &elem);
934         if (elem != NULL) {
935             found_request = true;
936 
937             struct aws_mqtt_request *request = elem->value;
938             on_complete = request->on_complete;
939             on_complete_ud = request->on_complete_ud;
940 
941             /* clean up request resources */
942             aws_hash_table_remove_element(&connection->synced_data.outstanding_requests_table, elem);
943             /* remove the request from the list, which is thread_data.ongoing_requests_list */
944             aws_linked_list_remove(&request->list_node);
945             aws_memory_pool_release(&connection->synced_data.requests_pool, request);
946         }
947         mqtt_connection_unlock_synced_data(connection);
948     } /* END CRITICAL SECTION */
949 
950     if (!found_request) {
951         AWS_LOGF_DEBUG(
952             AWS_LS_MQTT_CLIENT,
953             "id=%p: received completion for message id %" PRIu16
954             " but no outstanding request exists.  Assuming this is an ack of a resend when the first request has "
955             "already completed.",
956             (void *)connection,
957             packet_id);
958         return;
959     }
960 
961     /* Invoke the complete callback. */
962     if (on_complete) {
963         on_complete(connection, packet_id, error_code, on_complete_ud);
964     }
965 }
966 
/*
 * Heap-allocated context for a deferred disconnect. mqtt_disconnect_impl() allocates and schedules it, and
 * s_mqtt_disconnect_task() recovers it from the channel task pointer via AWS_CONTAINER_OF and frees it.
 */
struct mqtt_shutdown_task {
    /* Error code passed to aws_channel_shutdown() when the task runs. */
    int error_code;
    /* Channel task scheduled on the connection's channel. */
    struct aws_channel_task task;
};
971 
s_mqtt_disconnect_task(struct aws_channel_task * channel_task,void * arg,enum aws_task_status status)972 static void s_mqtt_disconnect_task(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
973 
974     (void)status;
975 
976     struct mqtt_shutdown_task *task = AWS_CONTAINER_OF(channel_task, struct mqtt_shutdown_task, task);
977     struct aws_mqtt_client_connection *connection = arg;
978 
979     AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Doing disconnect", (void *)connection);
980     { /* BEGIN CRITICAL SECTION */
981         mqtt_connection_lock_synced_data(connection);
982         /* If there is an outstanding reconnect task, cancel it */
983         if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING && connection->reconnect_task) {
984             aws_atomic_store_ptr(&connection->reconnect_task->connection_ptr, NULL);
985             /* If the reconnect_task isn't scheduled, free it */
986             if (connection->reconnect_task && !connection->reconnect_task->task.timestamp) {
987                 aws_mem_release(connection->reconnect_task->allocator, connection->reconnect_task);
988             }
989             connection->reconnect_task = NULL;
990         }
991         mqtt_connection_unlock_synced_data(connection);
992     } /* END CRITICAL SECTION */
993 
994     if (connection->slot && connection->slot->channel) {
995         aws_channel_shutdown(connection->slot->channel, task->error_code);
996     }
997 
998     aws_mem_release(connection->allocator, task);
999 }
1000 
mqtt_disconnect_impl(struct aws_mqtt_client_connection * connection,int error_code)1001 void mqtt_disconnect_impl(struct aws_mqtt_client_connection *connection, int error_code) {
1002     if (connection->slot) {
1003         struct mqtt_shutdown_task *shutdown_task =
1004             aws_mem_calloc(connection->allocator, 1, sizeof(struct mqtt_shutdown_task));
1005         shutdown_task->error_code = error_code;
1006         aws_channel_task_init(&shutdown_task->task, s_mqtt_disconnect_task, connection, "mqtt_disconnect");
1007         aws_channel_schedule_task_now(connection->slot->channel, &shutdown_task->task);
1008     }
1009 }
1010