1 /**
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  * SPDX-License-Identifier: Apache-2.0.
4  */
5 #include <aws/mqtt/client.h>
6 
7 #include <aws/mqtt/private/client_impl.h>
8 #include <aws/mqtt/private/mqtt_client_test_helper.h>
9 #include <aws/mqtt/private/packets.h>
10 #include <aws/mqtt/private/topic_tree.h>
11 
12 #include <aws/http/proxy.h>
13 
14 #include <aws/io/channel_bootstrap.h>
15 #include <aws/io/event_loop.h>
16 #include <aws/io/logging.h>
17 #include <aws/io/socket.h>
18 #include <aws/io/tls_channel_handler.h>
19 #include <aws/io/uri.h>
20 
21 #include <aws/common/clock.h>
22 #include <aws/common/task_scheduler.h>
23 
24 #include <inttypes.h>
25 
26 #ifdef AWS_MQTT_WITH_WEBSOCKETS
27 #    include <aws/http/connection.h>
28 #    include <aws/http/request_response.h>
29 #    include <aws/http/websocket.h>
30 #endif
31 
32 #ifdef _MSC_VER
33 #    pragma warning(disable : 4204)
34 #endif
35 
/* Default ping timeout: 3 seconds, expressed in nanoseconds. */
static const uint64_t s_default_ping_timeout_ns = 3ULL * 1000000000ULL;

/* Default keep-alive: 20 minutes, in seconds. This is the default (and max) for AWS IoT as of 2020.02.18 */
static const uint16_t s_default_keep_alive_sec = 20 * 60;
41 
/*
 * Forward declaration of the connect routine, which is defined later in the file (outside this view).
 * s_attempt_reconnect below calls it and treats a non-zero return as a failed attempt.
 */
static int s_mqtt_client_connect(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_client_on_connection_complete_fn *on_connection_complete,
    void *userdata);
46 /*******************************************************************************
47  * Helper functions
48  ******************************************************************************/
49 
mqtt_connection_lock_synced_data(struct aws_mqtt_client_connection * connection)50 void mqtt_connection_lock_synced_data(struct aws_mqtt_client_connection *connection) {
51     int err = aws_mutex_lock(&connection->synced_data.lock);
52     AWS_ASSERT(!err);
53     (void)err;
54 }
55 
mqtt_connection_unlock_synced_data(struct aws_mqtt_client_connection * connection)56 void mqtt_connection_unlock_synced_data(struct aws_mqtt_client_connection *connection) {
57     ASSERT_SYNCED_DATA_LOCK_HELD(connection);
58 
59     int err = aws_mutex_unlock(&connection->synced_data.lock);
60     AWS_ASSERT(!err);
61     (void)err;
62 }
63 
s_aws_mqtt_client_destroy(struct aws_mqtt_client * client)64 static void s_aws_mqtt_client_destroy(struct aws_mqtt_client *client) {
65 
66     AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "client=%p: Cleaning up MQTT client", (void *)client);
67     aws_client_bootstrap_release(client->bootstrap);
68 
69     aws_mem_release(client->allocator, client);
70 }
71 
mqtt_connection_set_state(struct aws_mqtt_client_connection * connection,enum aws_mqtt_client_connection_state state)72 void mqtt_connection_set_state(
73     struct aws_mqtt_client_connection *connection,
74     enum aws_mqtt_client_connection_state state) {
75     ASSERT_SYNCED_DATA_LOCK_HELD(connection);
76     if (connection->synced_data.state == state) {
77         AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: MQTT connection already in state %d", (void *)connection, state);
78         return;
79     }
80     connection->synced_data.state = state;
81 }
82 
struct request_timeout_wrapper;

/*
 * Argument of the request-timeout channel task (s_request_timeout).
 * s_schedule_timeout_task allocates this struct in the same block as its aws_channel_task, with this struct first.
 * Releasing this struct therefore frees the task too.
 */
struct request_timeout_task_arg {
    /* Packet id of the request that is completed with AWS_ERROR_MQTT_TIMEOUT if the timeout fires. */
    uint16_t packet_id;
    /* Connection that owns the request; its allocator is used to free this struct. */
    struct aws_mqtt_client_connection *connection;
    /* Back pointer to the wrapper embedded in the operation's task arg; NULL once the two have been detached. */
    struct request_timeout_wrapper *task_arg_wrapper;
};

/*
 * We want the timeout task to be able to destroy the forward reference from the operation's task arg structure
 * to the timeout task. But the operation task arg structures don't have any data structure in common. So, to allow
 * the timeout to refer back to a zero-able forward pointer, we wrap a pointer to the timeout task and embed it
 * in every operation's task arg that needs to create a timeout.
 */
struct request_timeout_wrapper {
    /* Forward pointer to the timeout task's argument. s_request_timeout sets it to NULL when the timeout task runs or
     * is cancelled, so the operation knows not to touch the (then freed) timeout task. */
    struct request_timeout_task_arg *timeout_task_arg;
};
101 
s_request_timeout(struct aws_channel_task * channel_task,void * arg,enum aws_task_status status)102 static void s_request_timeout(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
103     (void)channel_task;
104     struct request_timeout_task_arg *timeout_task_arg = arg;
105     struct aws_mqtt_client_connection *connection = timeout_task_arg->connection;
106 
107     if (status == AWS_TASK_STATUS_RUN_READY) {
108         if (timeout_task_arg->task_arg_wrapper != NULL) {
109             mqtt_request_complete(connection, AWS_ERROR_MQTT_TIMEOUT, timeout_task_arg->packet_id);
110         }
111     }
112 
113     /*
114      * Whether cancelled or run, if we have a back pointer to the operation's task arg, we must zero it out
115      * so that when it completes it does not try to cancel us, because we will already be freed.
116      *
117      * If we don't have a back pointer to the operation's task arg, that means it already ran and completed.
118      */
119     if (timeout_task_arg->task_arg_wrapper != NULL) {
120         timeout_task_arg->task_arg_wrapper->timeout_task_arg = NULL;
121         timeout_task_arg->task_arg_wrapper = NULL;
122     }
123 
124     aws_mem_release(connection->allocator, timeout_task_arg);
125 }
126 
s_schedule_timeout_task(struct aws_mqtt_client_connection * connection,uint16_t packet_id)127 static struct request_timeout_task_arg *s_schedule_timeout_task(
128     struct aws_mqtt_client_connection *connection,
129     uint16_t packet_id) {
130     /* schedule a timeout task to run, in case server consider the publish is not received */
131     struct aws_channel_task *request_timeout_task = NULL;
132     struct request_timeout_task_arg *timeout_task_arg = NULL;
133     if (!aws_mem_acquire_many(
134             connection->allocator,
135             2,
136             &timeout_task_arg,
137             sizeof(struct request_timeout_task_arg),
138             &request_timeout_task,
139             sizeof(struct aws_channel_task))) {
140         return NULL;
141     }
142     aws_channel_task_init(request_timeout_task, s_request_timeout, timeout_task_arg, "mqtt_request_timeout");
143     AWS_ZERO_STRUCT(*timeout_task_arg);
144     timeout_task_arg->connection = connection;
145     timeout_task_arg->packet_id = packet_id;
146     uint64_t timestamp = 0;
147     if (aws_channel_current_clock_time(connection->slot->channel, &timestamp)) {
148         aws_mem_release(connection->allocator, timeout_task_arg);
149         return NULL;
150     }
151     timestamp = aws_add_u64_saturating(timestamp, connection->operation_timeout_ns);
152     aws_channel_schedule_task_future(connection->slot->channel, request_timeout_task, timestamp);
153     return timeout_task_arg;
154 }
155 
156 /*******************************************************************************
157  * Client Init
158  ******************************************************************************/
aws_mqtt_client_new(struct aws_allocator * allocator,struct aws_client_bootstrap * bootstrap)159 struct aws_mqtt_client *aws_mqtt_client_new(struct aws_allocator *allocator, struct aws_client_bootstrap *bootstrap) {
160 
161     aws_mqtt_fatal_assert_library_initialized();
162 
163     struct aws_mqtt_client *client = aws_mem_calloc(allocator, 1, sizeof(struct aws_mqtt_client));
164     if (client == NULL) {
165         return NULL;
166     }
167 
168     AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "client=%p: Initalizing MQTT client", (void *)client);
169 
170     client->allocator = allocator;
171     client->bootstrap = aws_client_bootstrap_acquire(bootstrap);
172     aws_ref_count_init(&client->ref_count, client, (aws_simple_completion_callback *)s_aws_mqtt_client_destroy);
173 
174     return client;
175 }
176 
aws_mqtt_client_acquire(struct aws_mqtt_client * client)177 struct aws_mqtt_client *aws_mqtt_client_acquire(struct aws_mqtt_client *client) {
178     if (client != NULL) {
179         aws_ref_count_acquire(&client->ref_count);
180     }
181 
182     return client;
183 }
184 
aws_mqtt_client_release(struct aws_mqtt_client * client)185 void aws_mqtt_client_release(struct aws_mqtt_client *client) {
186     if (client != NULL) {
187         aws_ref_count_release(&client->ref_count);
188     }
189 }
190 
/*
 * Shutdown callback for the MQTT connection's channel. At this point, the channel for the MQTT connection has
 * completed its shutdown. s_mqtt_client_init also calls it directly when channel setup fails; in that case
 * error_code is non-zero and channel is NULL.
 *
 * Steps:
 *  - In-flight requests: with clean_session they are cancelled. Their callbacks receive
 *    AWS_ERROR_MQTT_CANCELLED_FOR_CLEAN_SESSION, and they are returned to the request pool. Otherwise the ongoing
 *    requests are moved back to the pending list.
 *  - The state observed at shutdown (prev_state) drives what happens next:
 *      CONNECTED     -> RECONNECTING; on_interrupted fires and a reconnect starts, unless the user disconnected
 *                       from inside on_interrupted.
 *      RECONNECTING  -> the reconnect task is rescheduled for reconnect_timeouts.next_attempt.
 *      DISCONNECTING -> DISCONNECTED; on_disconnect fires.
 *      CONNECTING    -> DISCONNECTED; on_connection_complete fires with the error.
 *  - The connection's slot is always removed.
 *  - Whenever the connection ends up DISCONNECTED, the reference held on the connection is released.
 */
static void s_mqtt_client_shutdown(
    struct aws_client_bootstrap *bootstrap,
    int error_code,
    struct aws_channel *channel,
    void *user_data) {

    (void)bootstrap;
    (void)channel;

    struct aws_mqtt_client_connection *connection = user_data;

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT, "id=%p: Channel has been shutdown with error code %d", (void *)connection, error_code);
    /* State at the moment of shutdown; decides which callbacks fire below. */
    enum aws_mqtt_client_connection_state prev_state;
    /* Requests to fail with AWS_ERROR_MQTT_CANCELLED_FOR_CLEAN_SESSION (only filled for clean sessions). */
    struct aws_linked_list cancelling_requests;
    aws_linked_list_init(&cancelling_requests);
    /* True when this shutdown must end in DISCONNECTED and release the connection reference. */
    bool disconnected_state = false;
    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        /* The responses that in-flight requests are waiting for will never arrive on this channel. Depending on the
         * session type, either cancel them or move them back to the pending list so they can be retried. */
        if (connection->clean_session) {
            /* For a clean session, the Session lasts as long as the Network Connection. Thus, discard the previous
             * session */
            AWS_LOGF_TRACE(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Discard ongoing requests and pending requests when a clean session connection lost.",
                (void *)connection);
            aws_linked_list_move_all_back(&cancelling_requests, &connection->thread_data.ongoing_requests_list);
            aws_linked_list_move_all_back(&cancelling_requests, &connection->synced_data.pending_requests_list);
        } else {
            aws_linked_list_move_all_back(
                &connection->synced_data.pending_requests_list, &connection->thread_data.ongoing_requests_list);
            AWS_LOGF_TRACE(
                AWS_LS_MQTT_CLIENT,
                "id=%p: All subscribe/unsubscribe and publish QoS>0 have been move to pending list",
                (void *)connection);
        }
        prev_state = connection->synced_data.state;
        switch (connection->synced_data.state) {
            case AWS_MQTT_CLIENT_STATE_CONNECTED:
                /* unexpected hangup from broker, try to reconnect */
                mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_RECONNECTING);
                AWS_LOGF_DEBUG(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: connection was unexpected interrupted, switch state to RECONNECTING.",
                    (void *)connection);
                break;
            case AWS_MQTT_CLIENT_STATE_DISCONNECTING:
                /* disconnect requested by user */
                /* Successfully shutdown, so clear the outstanding requests */
                /* TODO: respect the cleansession, clear the table when needed */
                aws_hash_table_clear(&connection->synced_data.outstanding_requests_table);
                /* The state itself is switched to DISCONNECTED later, outside this critical section. */
                disconnected_state = true;
                AWS_LOGF_DEBUG(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: disconnect finished, switch state to DISCONNECTED.",
                    (void *)connection);
                break;
            case AWS_MQTT_CLIENT_STATE_CONNECTING:
                /* failed to connect */
                disconnected_state = true;
                break;
            case AWS_MQTT_CLIENT_STATE_RECONNECTING:
                /* reconnect failed, schedule the next attempt later, no need to change the state. */
                break;
            default:
                /* AWS_MQTT_CLIENT_STATE_DISCONNECTED */
                break;
        }
        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT, "id=%p: current state is %d", (void *)connection, (int)connection->synced_data.state);
        /* Always clear slot, as that's what's been shutdown */
        if (connection->slot) {
            aws_channel_slot_remove(connection->slot);
            AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: slot is removed successfully", (void *)connection);
            connection->slot = NULL;
        }

        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */

    if (!aws_linked_list_empty(&cancelling_requests)) {
        /* Fire the user completion callbacks without holding the lock... */
        struct aws_linked_list_node *current = aws_linked_list_front(&cancelling_requests);
        const struct aws_linked_list_node *end = aws_linked_list_end(&cancelling_requests);
        while (current != end) {
            struct aws_mqtt_request *request = AWS_CONTAINER_OF(current, struct aws_mqtt_request, list_node);
            if (request->on_complete) {
                request->on_complete(
                    connection,
                    request->packet_id,
                    AWS_ERROR_MQTT_CANCELLED_FOR_CLEAN_SESSION,
                    request->on_complete_ud);
            }
            current = current->next;
        }
        /* ...then, under the lock, drop each request from the outstanding table and return it to the pool. */
        { /* BEGIN CRITICAL SECTION */
            mqtt_connection_lock_synced_data(connection);
            while (!aws_linked_list_empty(&cancelling_requests)) {
                struct aws_linked_list_node *node = aws_linked_list_pop_front(&cancelling_requests);
                struct aws_mqtt_request *request = AWS_CONTAINER_OF(node, struct aws_mqtt_request, list_node);
                aws_hash_table_remove(
                    &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
                aws_memory_pool_release(&connection->synced_data.requests_pool, request);
            }
            mqtt_connection_unlock_synced_data(connection);
        } /* END CRITICAL SECTION */
    }

    /* If there's no error code and this wasn't user-requested, set the error code to something useful */
    if (error_code == AWS_ERROR_SUCCESS) {
        if (prev_state != AWS_MQTT_CLIENT_STATE_DISCONNECTING && prev_state != AWS_MQTT_CLIENT_STATE_DISCONNECTED) {
            error_code = AWS_ERROR_MQTT_UNEXPECTED_HANGUP;
        }
    }
    switch (prev_state) {
        case AWS_MQTT_CLIENT_STATE_RECONNECTING: {
            /* If reconnect attempt failed, schedule the next attempt */
            struct aws_event_loop *el =
                aws_event_loop_group_get_next_loop(connection->client->bootstrap->event_loop_group);

            AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Reconnect failed, retrying", (void *)connection);

            aws_event_loop_schedule_task_future(
                el, &connection->reconnect_task->task, connection->reconnect_timeouts.next_attempt);
            break;
        }
        case AWS_MQTT_CLIENT_STATE_CONNECTED: {
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Connection interrupted, calling callback and attempting reconnect",
                (void *)connection);
            MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_interrupted, error_code);

            /* In case user called disconnect from the on_interrupted callback */
            bool stop_reconnect;
            { /* BEGIN CRITICAL SECTION */
                mqtt_connection_lock_synced_data(connection);
                stop_reconnect = connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING;
                if (stop_reconnect) {
                    disconnected_state = true;
                    AWS_LOGF_DEBUG(
                        AWS_LS_MQTT_CLIENT,
                        "id=%p: disconnect finished, switch state to DISCONNECTED.",
                        (void *)connection);
                }
                mqtt_connection_unlock_synced_data(connection);
            } /* END CRITICAL SECTION */
            if (!stop_reconnect) {
                /* Attempt the reconnect immediately, which will schedule a task to retry if it doesn't succeed */
                connection->reconnect_task->task.fn(
                    &connection->reconnect_task->task, connection->reconnect_task->task.arg, AWS_TASK_STATUS_RUN_READY);
            }
            break;
        }
        default:
            break;
    }
    if (disconnected_state) {
        { /* BEGIN CRITICAL SECTION */
            mqtt_connection_lock_synced_data(connection);
            mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_DISCONNECTED);
            mqtt_connection_unlock_synced_data(connection);
        } /* END CRITICAL SECTION */
        /* Notify the user according to how we got here. */
        switch (prev_state) {
            case AWS_MQTT_CLIENT_STATE_CONNECTED:
                AWS_LOGF_TRACE(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: Caller requested disconnect from on_interrupted callback, aborting reconnect",
                    (void *)connection);
                MQTT_CLIENT_CALL_CALLBACK(connection, on_disconnect);
                break;
            case AWS_MQTT_CLIENT_STATE_DISCONNECTING:
                AWS_LOGF_DEBUG(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: Disconnect completed, clearing request queue and calling callback",
                    (void *)connection);
                MQTT_CLIENT_CALL_CALLBACK(connection, on_disconnect);
                break;
            case AWS_MQTT_CLIENT_STATE_CONNECTING:
                AWS_LOGF_TRACE(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: Initial connection attempt failed, calling callback",
                    (void *)connection);
                MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_connection_complete, error_code, 0, false);
                break;
            default:
                break;
        }
        /* The connection can die now. Release the refcount */
        aws_mqtt_client_connection_release(connection);
    }
}
385 
386 /*******************************************************************************
387  * Connection New
388  ******************************************************************************/
389 /* The assumption here is that a connection always outlives its channels, and the channel this task was scheduled on
390  * always outlives this task, so all we need to do is check the connection state. If we are in a state that waits
391  * for a CONNACK, kill it off. In the case that the connection died between scheduling this task and it being executed
392  * the status will always be CANCELED because this task will be canceled when the owning channel goes away. */
s_connack_received_timeout(struct aws_channel_task * channel_task,void * arg,enum aws_task_status status)393 static void s_connack_received_timeout(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
394     struct aws_mqtt_client_connection *connection = arg;
395 
396     if (status == AWS_TASK_STATUS_RUN_READY) {
397         bool time_out = false;
398         { /* BEGIN CRITICAL SECTION */
399             mqtt_connection_lock_synced_data(connection);
400             time_out =
401                 (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_CONNECTING ||
402                  connection->synced_data.state == AWS_MQTT_CLIENT_STATE_RECONNECTING);
403             mqtt_connection_unlock_synced_data(connection);
404         } /* END CRITICAL SECTION */
405         if (time_out) {
406             AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: mqtt CONNACK response timeout detected", (void *)connection);
407             aws_channel_shutdown(connection->slot->channel, AWS_ERROR_MQTT_TIMEOUT);
408         }
409     }
410 
411     aws_mem_release(connection->allocator, channel_task);
412 }
413 
414 /**
415  * Channel has been initialized callback. Sets up channel handler and sends out CONNECT packet.
416  * The on_connack callback is called with the CONNACK packet is received from the server.
417  */
s_mqtt_client_init(struct aws_client_bootstrap * bootstrap,int error_code,struct aws_channel * channel,void * user_data)418 static void s_mqtt_client_init(
419     struct aws_client_bootstrap *bootstrap,
420     int error_code,
421     struct aws_channel *channel,
422     void *user_data) {
423 
424     (void)bootstrap;
425     struct aws_io_message *message = NULL;
426 
427     /* Setup callback contract is: if error_code is non-zero then channel is NULL. */
428     AWS_FATAL_ASSERT((error_code != 0) == (channel == NULL));
429 
430     struct aws_mqtt_client_connection *connection = user_data;
431 
432     if (error_code != AWS_OP_SUCCESS) {
433         /* client shutdown already handles this case, so just call that. */
434         s_mqtt_client_shutdown(bootstrap, error_code, channel, user_data);
435         return;
436     }
437 
438     /* user requested disconnect before the channel has been set up. Stop installing the slot and sending CONNECT. */
439     bool failed_create_slot = false;
440 
441     { /* BEGIN CRITICAL SECTION */
442         mqtt_connection_lock_synced_data(connection);
443 
444         if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
445             /* It only happens when the user request disconnect during reconnecting, we don't need to fire any callback.
446              * The on_disconnect will be invoked as channel finish shutting down. */
447             mqtt_connection_unlock_synced_data(connection);
448             aws_channel_shutdown(channel, AWS_ERROR_SUCCESS);
449             return;
450         }
451         /* Create the slot */
452         connection->slot = aws_channel_slot_new(channel);
453         if (!connection->slot) {
454             failed_create_slot = true;
455         }
456         mqtt_connection_unlock_synced_data(connection);
457     } /* END CRITICAL SECTION */
458 
459     /* intall the slot and handler */
460     if (failed_create_slot) {
461 
462         AWS_LOGF_ERROR(
463             AWS_LS_MQTT_CLIENT,
464             "id=%p: Failed to create new slot, something has gone horribly wrong, error %d (%s).",
465             (void *)connection,
466             aws_last_error(),
467             aws_error_name(aws_last_error()));
468         goto handle_error;
469     }
470 
471     if (aws_channel_slot_insert_end(channel, connection->slot)) {
472         AWS_LOGF_ERROR(
473             AWS_LS_MQTT_CLIENT,
474             "id=%p: Failed to insert slot into channel %p, error %d (%s).",
475             (void *)connection,
476             (void *)channel,
477             aws_last_error(),
478             aws_error_name(aws_last_error()));
479         goto handle_error;
480     }
481 
482     if (aws_channel_slot_set_handler(connection->slot, &connection->handler)) {
483         AWS_LOGF_ERROR(
484             AWS_LS_MQTT_CLIENT,
485             "id=%p: Failed to set MQTT handler into slot on channel %p, error %d (%s).",
486             (void *)connection,
487             (void *)channel,
488             aws_last_error(),
489             aws_error_name(aws_last_error()));
490 
491         goto handle_error;
492     }
493 
494     AWS_LOGF_DEBUG(
495         AWS_LS_MQTT_CLIENT, "id=%p: Connection successfully opened, sending CONNECT packet", (void *)connection);
496 
497     struct aws_channel_task *connack_task = aws_mem_calloc(connection->allocator, 1, sizeof(struct aws_channel_task));
498     if (!connack_task) {
499         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to allocate timeout task.", (void *)connection);
500         goto handle_error;
501     }
502 
503     aws_channel_task_init(connack_task, s_connack_received_timeout, connection, "mqtt_connack_timeout");
504 
505     uint64_t now = 0;
506     if (aws_channel_current_clock_time(channel, &now)) {
507         AWS_LOGF_ERROR(
508             AWS_LS_MQTT_CLIENT,
509             "static: Failed to setting MQTT handler into slot on channel %p, error %d (%s).",
510             (void *)channel,
511             aws_last_error(),
512             aws_error_name(aws_last_error()));
513 
514         goto handle_error;
515     }
516     now += connection->ping_timeout_ns;
517     aws_channel_schedule_task_future(channel, connack_task, now);
518 
519     /* Send the connect packet */
520     struct aws_mqtt_packet_connect connect;
521     aws_mqtt_packet_connect_init(
522         &connect,
523         aws_byte_cursor_from_buf(&connection->client_id),
524         connection->clean_session,
525         connection->keep_alive_time_secs);
526 
527     if (connection->will.topic.buffer) {
528         /* Add will if present */
529 
530         struct aws_byte_cursor topic_cur = aws_byte_cursor_from_buf(&connection->will.topic);
531         struct aws_byte_cursor payload_cur = aws_byte_cursor_from_buf(&connection->will.payload);
532 
533         AWS_LOGF_DEBUG(
534             AWS_LS_MQTT_CLIENT,
535             "id=%p: Adding will to connection on " PRInSTR " with payload " PRInSTR,
536             (void *)connection,
537             AWS_BYTE_CURSOR_PRI(topic_cur),
538             AWS_BYTE_CURSOR_PRI(payload_cur));
539         aws_mqtt_packet_connect_add_will(
540             &connect, topic_cur, connection->will.qos, connection->will.retain, payload_cur);
541     }
542 
543     if (connection->username) {
544         struct aws_byte_cursor username_cur = aws_byte_cursor_from_string(connection->username);
545 
546         AWS_LOGF_DEBUG(
547             AWS_LS_MQTT_CLIENT,
548             "id=%p: Adding username " PRInSTR " to connection",
549             (void *)connection,
550             AWS_BYTE_CURSOR_PRI(username_cur))
551 
552         struct aws_byte_cursor password_cur = {
553             .ptr = NULL,
554             .len = 0,
555         };
556 
557         if (connection->password) {
558             password_cur = aws_byte_cursor_from_string(connection->password);
559         }
560 
561         aws_mqtt_packet_connect_add_credentials(&connect, username_cur, password_cur);
562     }
563 
564     message = mqtt_get_message_for_packet(connection, &connect.fixed_header);
565     if (!message) {
566 
567         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to get message from pool", (void *)connection);
568         goto handle_error;
569     }
570 
571     if (aws_mqtt_packet_connect_encode(&message->message_data, &connect)) {
572 
573         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to encode CONNECT packet", (void *)connection);
574         goto handle_error;
575     }
576 
577     if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
578 
579         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to send encoded CONNECT packet upstream", (void *)connection);
580         goto handle_error;
581     }
582 
583     return;
584 
585 handle_error:
586     MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_connection_complete, aws_last_error(), 0, false);
587     aws_channel_shutdown(channel, aws_last_error());
588 
589     if (message) {
590         aws_mem_release(message->allocator, message);
591     }
592 }
593 
s_attempt_reconnect(struct aws_task * task,void * userdata,enum aws_task_status status)594 static void s_attempt_reconnect(struct aws_task *task, void *userdata, enum aws_task_status status) {
595 
596     (void)task;
597 
598     struct aws_mqtt_reconnect_task *reconnect = userdata;
599     struct aws_mqtt_client_connection *connection = aws_atomic_load_ptr(&reconnect->connection_ptr);
600 
601     if (status == AWS_TASK_STATUS_RUN_READY && connection) {
602         /* If the task is not cancelled and a connection has not succeeded, attempt reconnect */
603 
604         aws_high_res_clock_get_ticks(&connection->reconnect_timeouts.next_attempt);
605         connection->reconnect_timeouts.next_attempt += aws_timestamp_convert(
606             connection->reconnect_timeouts.current, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL);
607 
608         AWS_LOGF_TRACE(
609             AWS_LS_MQTT_CLIENT,
610             "id=%p: Attempting reconnect, if it fails next attempt will be in %" PRIu64 " seconds",
611             (void *)connection,
612             connection->reconnect_timeouts.current);
613 
614         /* Check before multiplying to avoid potential overflow */
615         if (connection->reconnect_timeouts.current > connection->reconnect_timeouts.max / 2) {
616             connection->reconnect_timeouts.current = connection->reconnect_timeouts.max;
617         } else {
618             connection->reconnect_timeouts.current *= 2;
619         }
620 
621         if (s_mqtt_client_connect(
622                 connection, connection->on_connection_complete, connection->on_connection_complete_ud)) {
623 
624             /* If reconnect attempt failed, schedule the next attempt */
625             struct aws_event_loop *el =
626                 aws_event_loop_group_get_next_loop(connection->client->bootstrap->event_loop_group);
627             aws_event_loop_schedule_task_future(
628                 el, &connection->reconnect_task->task, connection->reconnect_timeouts.next_attempt);
629             AWS_LOGF_TRACE(
630                 AWS_LS_MQTT_CLIENT,
631                 "id=%p: Scheduling reconnect, for %" PRIu64 " on event-loop %p",
632                 (void *)connection,
633                 connection->reconnect_timeouts.next_attempt,
634                 (void *)el);
635         } else {
636             connection->reconnect_task->task.timestamp = 0;
637         }
638     } else {
639         aws_mem_release(reconnect->allocator, reconnect);
640     }
641 }
642 
aws_create_reconnect_task(struct aws_mqtt_client_connection * connection)643 void aws_create_reconnect_task(struct aws_mqtt_client_connection *connection) {
644     if (connection->reconnect_task == NULL) {
645         connection->reconnect_task = aws_mem_calloc(connection->allocator, 1, sizeof(struct aws_mqtt_reconnect_task));
646         AWS_FATAL_ASSERT(connection->reconnect_task != NULL);
647 
648         aws_atomic_init_ptr(&connection->reconnect_task->connection_ptr, connection);
649         connection->reconnect_task->allocator = connection->allocator;
650         aws_task_init(
651             &connection->reconnect_task->task, s_attempt_reconnect, connection->reconnect_task, "mqtt_reconnect");
652     }
653 }
654 
/*
 * Hash callback for tables keyed by a uint16_t (e.g. packet ids): the key value is its own hash.
 * The key is only read, so the const qualifier is preserved instead of cast away.
 */
static uint64_t s_hash_uint16_t(const void *item) {
    const uint16_t *key = item;
    return (uint64_t)*key;
}
658 
/*
 * Equality callback for tables keyed by a uint16_t.
 * Both keys are only read, so the const qualifier is preserved instead of cast away.
 */
static bool s_uint16_t_eq(const void *a, const void *b) {
    const uint16_t *lhs = a;
    const uint16_t *rhs = b;
    return *lhs == *rhs;
}
662 
s_mqtt_client_connection_destroy_final(struct aws_mqtt_client_connection * connection)663 static void s_mqtt_client_connection_destroy_final(struct aws_mqtt_client_connection *connection) {
664     AWS_PRECONDITION(!connection || connection->allocator);
665     if (!connection) {
666         return;
667     }
668 
669     /* If the slot is not NULL, the connection is still connected, which should be prevented from calling this function
670      */
671     AWS_ASSERT(!connection->slot);
672 
673     AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Destroying connection", (void *)connection);
674 
675     /* If the reconnect_task isn't freed, free it */
676     if (connection->reconnect_task) {
677         aws_mem_release(connection->reconnect_task->allocator, connection->reconnect_task);
678     }
679     aws_string_destroy(connection->host_name);
680 
681     /* Clear the credentials */
682     if (connection->username) {
683         aws_string_destroy_secure(connection->username);
684     }
685     if (connection->password) {
686         aws_string_destroy_secure(connection->password);
687     }
688 
689     /* Clean up the will */
690     aws_byte_buf_clean_up(&connection->will.topic);
691     aws_byte_buf_clean_up(&connection->will.payload);
692 
693     /* Clear the client_id */
694     aws_byte_buf_clean_up(&connection->client_id);
695 
696     /* Free all of the active subscriptions */
697     aws_mqtt_topic_tree_clean_up(&connection->thread_data.subscriptions);
698 
699     aws_hash_table_clean_up(&connection->synced_data.outstanding_requests_table);
700     /* clean up the pending_requests if it's not empty */
701     while (!aws_linked_list_empty(&connection->synced_data.pending_requests_list)) {
702         struct aws_linked_list_node *node = aws_linked_list_pop_front(&connection->synced_data.pending_requests_list);
703         struct aws_mqtt_request *request = AWS_CONTAINER_OF(node, struct aws_mqtt_request, list_node);
704         /* Fire the callback and clean up the memory, as the connection get destroyed. */
705         if (request->on_complete) {
706             request->on_complete(
707                 connection, request->packet_id, AWS_ERROR_MQTT_CONNECTION_DESTROYED, request->on_complete_ud);
708         }
709         aws_memory_pool_release(&connection->synced_data.requests_pool, request);
710     }
711     aws_memory_pool_clean_up(&connection->synced_data.requests_pool);
712 
713     aws_mutex_clean_up(&connection->synced_data.lock);
714 
715     aws_tls_connection_options_clean_up(&connection->tls_options);
716 
717     /* Clean up the websocket proxy options */
718     if (connection->http_proxy_config) {
719         aws_http_proxy_config_destroy(connection->http_proxy_config);
720         connection->http_proxy_config = NULL;
721     }
722 
723     aws_mqtt_client_release(connection->client);
724 
725     /* Frees all allocated memory */
726     aws_mem_release(connection->allocator, connection);
727 }
728 
/*
 * on_disconnect callback installed when the last reference is released while the connection is still live:
 * once the disconnect completes, finish destroying the connection. The user data argument is unused.
 */
static void s_on_final_disconnect(struct aws_mqtt_client_connection *connection, void *userdata) {
    (void)userdata;
    s_mqtt_client_connection_destroy_final(connection);
}
734 
s_mqtt_client_connection_start_destroy(struct aws_mqtt_client_connection * connection)735 static void s_mqtt_client_connection_start_destroy(struct aws_mqtt_client_connection *connection) {
736     bool call_destroy_final = false;
737 
738     AWS_LOGF_DEBUG(
739         AWS_LS_MQTT_CLIENT,
740         "id=%p: Last refcount on connection has been released, start destroying the connection.",
741         (void *)connection);
742     { /* BEGIN CRITICAL SECTION */
743         mqtt_connection_lock_synced_data(connection);
744         if (connection->synced_data.state != AWS_MQTT_CLIENT_STATE_DISCONNECTED) {
745             /*
746              * We don't call the on_disconnect callback until we've transitioned to the DISCONNECTED state.  So it's
747              * safe to change it now while we hold the lock since we know we're not DISCONNECTED yet.
748              */
749             connection->on_disconnect = s_on_final_disconnect;
750 
751             if (connection->synced_data.state != AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
752                 mqtt_disconnect_impl(connection, AWS_ERROR_SUCCESS);
753                 AWS_LOGF_DEBUG(
754                     AWS_LS_MQTT_CLIENT,
755                     "id=%p: final refcount has been released, switch state to DISCONNECTING.",
756                     (void *)connection);
757                 mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_DISCONNECTING);
758             }
759         } else {
760             call_destroy_final = true;
761         }
762 
763         mqtt_connection_unlock_synced_data(connection);
764     } /* END CRITICAL SECTION */
765 
766     if (call_destroy_final) {
767         s_mqtt_client_connection_destroy_final(connection);
768     }
769 }
770 
aws_mqtt_client_connection_new(struct aws_mqtt_client * client)771 struct aws_mqtt_client_connection *aws_mqtt_client_connection_new(struct aws_mqtt_client *client) {
772     AWS_PRECONDITION(client);
773 
774     struct aws_mqtt_client_connection *connection =
775         aws_mem_calloc(client->allocator, 1, sizeof(struct aws_mqtt_client_connection));
776     if (!connection) {
777         return NULL;
778     }
779 
780     AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Creating new connection", (void *)connection);
781 
782     /* Initialize the client */
783     connection->allocator = client->allocator;
784     aws_ref_count_init(
785         &connection->ref_count, connection, (aws_simple_completion_callback *)s_mqtt_client_connection_start_destroy);
786     connection->client = aws_mqtt_client_acquire(client);
787     AWS_ZERO_STRUCT(connection->synced_data);
788     connection->synced_data.state = AWS_MQTT_CLIENT_STATE_DISCONNECTED;
789     connection->reconnect_timeouts.min = 1;
790     connection->reconnect_timeouts.max = 128;
791     aws_linked_list_init(&connection->synced_data.pending_requests_list);
792     aws_linked_list_init(&connection->thread_data.ongoing_requests_list);
793 
794     if (aws_mutex_init(&connection->synced_data.lock)) {
795         AWS_LOGF_ERROR(
796             AWS_LS_MQTT_CLIENT,
797             "id=%p: Failed to initialize mutex, error %d (%s)",
798             (void *)connection,
799             aws_last_error(),
800             aws_error_name(aws_last_error()));
801         goto failed_init_mutex;
802     }
803 
804     if (aws_mqtt_topic_tree_init(&connection->thread_data.subscriptions, connection->allocator)) {
805 
806         AWS_LOGF_ERROR(
807             AWS_LS_MQTT_CLIENT,
808             "id=%p: Failed to initialize subscriptions topic_tree, error %d (%s)",
809             (void *)connection,
810             aws_last_error(),
811             aws_error_name(aws_last_error()));
812         goto failed_init_subscriptions;
813     }
814 
815     if (aws_memory_pool_init(
816             &connection->synced_data.requests_pool, connection->allocator, 32, sizeof(struct aws_mqtt_request))) {
817 
818         AWS_LOGF_ERROR(
819             AWS_LS_MQTT_CLIENT,
820             "id=%p: Failed to initialize request pool, error %d (%s)",
821             (void *)connection,
822             aws_last_error(),
823             aws_error_name(aws_last_error()));
824         goto failed_init_requests_pool;
825     }
826 
827     if (aws_hash_table_init(
828             &connection->synced_data.outstanding_requests_table,
829             connection->allocator,
830             sizeof(struct aws_mqtt_request *),
831             s_hash_uint16_t,
832             s_uint16_t_eq,
833             NULL,
834             NULL)) {
835 
836         AWS_LOGF_ERROR(
837             AWS_LS_MQTT_CLIENT,
838             "id=%p: Failed to initialize outstanding requests table, error %d (%s)",
839             (void *)connection,
840             aws_last_error(),
841             aws_error_name(aws_last_error()));
842         goto failed_init_outstanding_requests_table;
843     }
844 
845     /* Initialize the handler */
846     connection->handler.alloc = connection->allocator;
847     connection->handler.vtable = aws_mqtt_get_client_channel_vtable();
848     connection->handler.impl = connection;
849 
850     return connection;
851 
852 failed_init_outstanding_requests_table:
853     aws_memory_pool_clean_up(&connection->synced_data.requests_pool);
854 
855 failed_init_requests_pool:
856     aws_mqtt_topic_tree_clean_up(&connection->thread_data.subscriptions);
857 
858 failed_init_subscriptions:
859     aws_mutex_clean_up(&connection->synced_data.lock);
860 
861 failed_init_mutex:
862     aws_mem_release(client->allocator, connection);
863 
864     return NULL;
865 }
866 
aws_mqtt_client_connection_acquire(struct aws_mqtt_client_connection * connection)867 struct aws_mqtt_client_connection *aws_mqtt_client_connection_acquire(struct aws_mqtt_client_connection *connection) {
868     if (connection != NULL) {
869         aws_ref_count_acquire(&connection->ref_count);
870     }
871 
872     return connection;
873 }
874 
aws_mqtt_client_connection_release(struct aws_mqtt_client_connection * connection)875 void aws_mqtt_client_connection_release(struct aws_mqtt_client_connection *connection) {
876     if (connection != NULL) {
877         aws_ref_count_release(&connection->ref_count);
878     }
879 }
880 
881 /*******************************************************************************
882  * Connection Configuration
883  ******************************************************************************/
884 
885 /* To configure the connection, ensure the state is DISCONNECTED or CONNECTED */
s_check_connection_state_for_configuration(struct aws_mqtt_client_connection * connection)886 static int s_check_connection_state_for_configuration(struct aws_mqtt_client_connection *connection) {
887     int result = AWS_OP_SUCCESS;
888     { /* BEGIN CRITICAL SECTION */
889         mqtt_connection_lock_synced_data(connection);
890 
891         if (connection->synced_data.state != AWS_MQTT_CLIENT_STATE_DISCONNECTED &&
892             connection->synced_data.state != AWS_MQTT_CLIENT_STATE_CONNECTED) {
893             AWS_LOGF_ERROR(
894                 AWS_LS_MQTT_CLIENT,
895                 "id=%p: Connection is currently pending connect/disconnect. Unable to make configuration changes until "
896                 "pending operation completes.",
897                 (void *)connection);
898             result = AWS_OP_ERR;
899         }
900         mqtt_connection_unlock_synced_data(connection);
901     } /* END CRITICAL SECTION */
902     return result;
903 }
904 
aws_mqtt_client_connection_set_will(struct aws_mqtt_client_connection * connection,const struct aws_byte_cursor * topic,enum aws_mqtt_qos qos,bool retain,const struct aws_byte_cursor * payload)905 int aws_mqtt_client_connection_set_will(
906     struct aws_mqtt_client_connection *connection,
907     const struct aws_byte_cursor *topic,
908     enum aws_mqtt_qos qos,
909     bool retain,
910     const struct aws_byte_cursor *payload) {
911 
912     AWS_PRECONDITION(connection);
913     AWS_PRECONDITION(topic);
914     if (s_check_connection_state_for_configuration(connection)) {
915         return aws_raise_error(AWS_ERROR_INVALID_STATE);
916     }
917 
918     int result = AWS_OP_ERR;
919     AWS_LOGF_TRACE(
920         AWS_LS_MQTT_CLIENT,
921         "id=%p: Setting last will with topic \"" PRInSTR "\"",
922         (void *)connection,
923         AWS_BYTE_CURSOR_PRI(*topic));
924 
925     if (!aws_mqtt_is_valid_topic(topic)) {
926         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Will topic is invalid", (void *)connection);
927         return aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
928     }
929 
930     struct aws_byte_buf local_topic_buf;
931     struct aws_byte_buf local_payload_buf;
932     AWS_ZERO_STRUCT(local_topic_buf);
933     AWS_ZERO_STRUCT(local_payload_buf);
934     struct aws_byte_buf topic_buf = aws_byte_buf_from_array(topic->ptr, topic->len);
935     if (aws_byte_buf_init_copy(&local_topic_buf, connection->allocator, &topic_buf)) {
936         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to copy will topic", (void *)connection);
937         goto cleanup;
938     }
939 
940     connection->will.qos = qos;
941     connection->will.retain = retain;
942 
943     struct aws_byte_buf payload_buf = aws_byte_buf_from_array(payload->ptr, payload->len);
944     if (aws_byte_buf_init_copy(&local_payload_buf, connection->allocator, &payload_buf)) {
945         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to copy will body", (void *)connection);
946         goto cleanup;
947     }
948 
949     if (connection->will.topic.len) {
950         AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Will has been set before, resetting it.", (void *)connection);
951     }
952     /* Succeed. */
953     result = AWS_OP_SUCCESS;
954 
955     /* swap the local buffer with connection */
956     struct aws_byte_buf temp = local_topic_buf;
957     local_topic_buf = connection->will.topic;
958     connection->will.topic = temp;
959     temp = local_payload_buf;
960     local_payload_buf = connection->will.payload;
961     connection->will.payload = temp;
962 
963 cleanup:
964     aws_byte_buf_clean_up(&local_topic_buf);
965     aws_byte_buf_clean_up(&local_payload_buf);
966 
967     return result;
968 }
969 
aws_mqtt_client_connection_set_login(struct aws_mqtt_client_connection * connection,const struct aws_byte_cursor * username,const struct aws_byte_cursor * password)970 int aws_mqtt_client_connection_set_login(
971     struct aws_mqtt_client_connection *connection,
972     const struct aws_byte_cursor *username,
973     const struct aws_byte_cursor *password) {
974 
975     AWS_PRECONDITION(connection);
976     AWS_PRECONDITION(username);
977     if (s_check_connection_state_for_configuration(connection)) {
978         return aws_raise_error(AWS_ERROR_INVALID_STATE);
979     }
980 
981     int result = AWS_OP_ERR;
982     AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Setting username and password", (void *)connection);
983 
984     struct aws_string *username_string = NULL;
985     struct aws_string *password_string = NULL;
986 
987     username_string = aws_string_new_from_array(connection->allocator, username->ptr, username->len);
988     if (!username_string) {
989         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to copy username", (void *)connection);
990         goto cleanup;
991     }
992 
993     if (password) {
994         password_string = aws_string_new_from_array(connection->allocator, password->ptr, password->len);
995         if (!password_string) {
996             AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to copy password", (void *)connection);
997             goto cleanup;
998         }
999     }
1000 
1001     if (connection->username) {
1002         AWS_LOGF_TRACE(
1003             AWS_LS_MQTT_CLIENT, "id=%p: Login information has been set before, resetting it.", (void *)connection);
1004     }
1005     /* Succeed. */
1006     result = AWS_OP_SUCCESS;
1007 
1008     /* swap the local string with connection */
1009     struct aws_string *temp = username_string;
1010     username_string = connection->username;
1011     connection->username = temp;
1012     temp = password_string;
1013     password_string = connection->password;
1014     connection->password = temp;
1015 
1016 cleanup:
1017     aws_string_destroy_secure(username_string);
1018     aws_string_destroy_secure(password_string);
1019 
1020     return result;
1021 }
1022 
aws_mqtt_client_connection_set_reconnect_timeout(struct aws_mqtt_client_connection * connection,uint64_t min_timeout,uint64_t max_timeout)1023 int aws_mqtt_client_connection_set_reconnect_timeout(
1024     struct aws_mqtt_client_connection *connection,
1025     uint64_t min_timeout,
1026     uint64_t max_timeout) {
1027 
1028     AWS_PRECONDITION(connection);
1029     if (s_check_connection_state_for_configuration(connection)) {
1030         return aws_raise_error(AWS_ERROR_INVALID_STATE);
1031     }
1032     AWS_LOGF_TRACE(
1033         AWS_LS_MQTT_CLIENT,
1034         "id=%p: Setting reconnect timeouts min: %" PRIu64 " max: %" PRIu64,
1035         (void *)connection,
1036         min_timeout,
1037         max_timeout);
1038     connection->reconnect_timeouts.min = min_timeout;
1039     connection->reconnect_timeouts.max = max_timeout;
1040 
1041     return AWS_OP_SUCCESS;
1042 }
1043 
/*
 * Registers the callbacks (and their user data) invoked when the connection is interrupted and when it resumes.
 * Raises AWS_ERROR_INVALID_STATE if a connect/disconnect is in flight; otherwise returns AWS_OP_SUCCESS.
 */
int aws_mqtt_client_connection_set_connection_interruption_handlers(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_client_on_connection_interrupted_fn *on_interrupted,
    void *on_interrupted_ud,
    aws_mqtt_client_on_connection_resumed_fn *on_resumed,
    void *on_resumed_ud) {

    AWS_PRECONDITION(connection);
    if (s_check_connection_state_for_configuration(connection) != AWS_OP_SUCCESS) {
        return aws_raise_error(AWS_ERROR_INVALID_STATE);
    }

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT, "id=%p: Setting connection interrupted and resumed handlers", (void *)connection);

    /* Interruption handler. */
    connection->on_interrupted = on_interrupted;
    connection->on_interrupted_ud = on_interrupted_ud;

    /* Resumption handler. */
    connection->on_resumed = on_resumed;
    connection->on_resumed_ud = on_resumed_ud;

    return AWS_OP_SUCCESS;
}
1065 
/*
 * Registers a callback invoked for every incoming publish, regardless of topic.
 *
 * Refused with AWS_ERROR_INVALID_STATE while CONNECTED, because publishes may arrive at any time. Any other state
 * is accepted. Returns AWS_OP_SUCCESS on success.
 * NOTE(review): the state is checked under the lock, but the handler fields are written after the lock is released.
 */
int aws_mqtt_client_connection_set_on_any_publish_handler(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_client_publish_received_fn *on_any_publish,
    void *on_any_publish_ud) {

    AWS_PRECONDITION(connection);

    mqtt_connection_lock_synced_data(connection);
    /* BEGIN CRITICAL SECTION */
    bool is_connected = connection->synced_data.state == AWS_MQTT_CLIENT_STATE_CONNECTED;
    mqtt_connection_unlock_synced_data(connection);
    /* END CRITICAL SECTION */

    if (is_connected) {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Connection is connected, publishes may arrive anytime. Unable to set publish handler until "
            "offline.",
            (void *)connection);
        return aws_raise_error(AWS_ERROR_INVALID_STATE);
    }

    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Setting on_any_publish handler", (void *)connection);

    connection->on_any_publish = on_any_publish;
    connection->on_any_publish_ud = on_any_publish_ud;

    return AWS_OP_SUCCESS;
}
1094 
1095 /*******************************************************************************
1096  * Websockets
1097  ******************************************************************************/
1098 #ifdef AWS_MQTT_WITH_WEBSOCKETS
1099 
/*
 * Switches the connection to MQTT-over-websockets.
 *
 * `transformer` (optional) may modify the websocket handshake request before it is sent; `validator` (optional) may
 * accept or reject the handshake response. Each callback receives its matching user-data pointer.
 * Always returns AWS_OP_SUCCESS.
 */
int aws_mqtt_client_connection_use_websockets(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_transform_websocket_handshake_fn *transformer,
    void *transformer_ud,
    aws_mqtt_validate_websocket_handshake_fn *validator,
    void *validator_ud) {

    /* Request transform hook. */
    connection->websocket.handshake_transformer = transformer;
    connection->websocket.handshake_transformer_ud = transformer_ud;

    /* Response validation hook. */
    connection->websocket.handshake_validator = validator;
    connection->websocket.handshake_validator_ud = validator_ud;

    connection->websocket.enabled = true;

    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Using websockets", (void *)connection);
    return AWS_OP_SUCCESS;
}
1117 
aws_mqtt_client_connection_set_http_proxy_options(struct aws_mqtt_client_connection * connection,struct aws_http_proxy_options * proxy_options)1118 int aws_mqtt_client_connection_set_http_proxy_options(
1119     struct aws_mqtt_client_connection *connection,
1120     struct aws_http_proxy_options *proxy_options) {
1121 
1122     /* If there is existing proxy options, nuke em */
1123     if (connection->http_proxy_config) {
1124         aws_http_proxy_config_destroy(connection->http_proxy_config);
1125         connection->http_proxy_config = NULL;
1126     }
1127 
1128     connection->http_proxy_config =
1129         aws_http_proxy_config_new_tunneling_from_proxy_options(connection->allocator, proxy_options);
1130 
1131     return connection->http_proxy_config != NULL ? AWS_OP_SUCCESS : AWS_OP_ERR;
1132 }
1133 
s_on_websocket_shutdown(struct aws_websocket * websocket,int error_code,void * user_data)1134 static void s_on_websocket_shutdown(struct aws_websocket *websocket, int error_code, void *user_data) {
1135     struct aws_mqtt_client_connection *connection = user_data;
1136 
1137     struct aws_channel *channel = connection->slot ? connection->slot->channel : NULL;
1138 
1139     s_mqtt_client_shutdown(connection->client->bootstrap, error_code, channel, connection);
1140 
1141     if (websocket) {
1142         aws_websocket_release(websocket);
1143     }
1144 }
1145 
s_on_websocket_setup(struct aws_websocket * websocket,int error_code,int handshake_response_status,const struct aws_http_header * handshake_response_header_array,size_t num_handshake_response_headers,void * user_data)1146 static void s_on_websocket_setup(
1147     struct aws_websocket *websocket,
1148     int error_code,
1149     int handshake_response_status,
1150     const struct aws_http_header *handshake_response_header_array,
1151     size_t num_handshake_response_headers,
1152     void *user_data) {
1153 
1154     (void)handshake_response_status;
1155 
1156     /* Setup callback contract is: if error_code is non-zero then websocket is NULL. */
1157     AWS_FATAL_ASSERT((error_code != 0) == (websocket == NULL));
1158 
1159     struct aws_mqtt_client_connection *connection = user_data;
1160     struct aws_channel *channel = NULL;
1161 
1162     if (connection->websocket.handshake_request) {
1163         aws_http_message_release(connection->websocket.handshake_request);
1164         connection->websocket.handshake_request = NULL;
1165     }
1166 
1167     if (websocket) {
1168         channel = aws_websocket_get_channel(websocket);
1169         AWS_ASSERT(channel);
1170 
1171         /* Websocket must be "converted" before the MQTT handler can be installed next to it. */
1172         if (aws_websocket_convert_to_midchannel_handler(websocket)) {
1173             AWS_LOGF_ERROR(
1174                 AWS_LS_MQTT_CLIENT,
1175                 "id=%p: Failed converting websocket, error %d (%s)",
1176                 (void *)connection,
1177                 aws_last_error(),
1178                 aws_error_name(aws_last_error()));
1179 
1180             aws_channel_shutdown(channel, aws_last_error());
1181             return;
1182         }
1183 
1184         /* If validation callback is set, let the user accept/reject the handshake */
1185         if (connection->websocket.handshake_validator) {
1186             AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Validating websocket handshake response.", (void *)connection);
1187 
1188             if (connection->websocket.handshake_validator(
1189                     connection,
1190                     handshake_response_header_array,
1191                     num_handshake_response_headers,
1192                     connection->websocket.handshake_validator_ud)) {
1193 
1194                 AWS_LOGF_ERROR(
1195                     AWS_LS_MQTT_CLIENT,
1196                     "id=%p: Failure reported by websocket handshake validator callback, error %d (%s)",
1197                     (void *)connection,
1198                     aws_last_error(),
1199                     aws_error_name(aws_last_error()));
1200 
1201                 aws_channel_shutdown(channel, aws_last_error());
1202                 return;
1203             }
1204 
1205             AWS_LOGF_TRACE(
1206                 AWS_LS_MQTT_CLIENT, "id=%p: Done validating websocket handshake response.", (void *)connection);
1207         }
1208     }
1209 
1210     /* Call into the channel-setup callback, the rest of the logic is the same. */
1211     s_mqtt_client_init(connection->client->bootstrap, error_code, channel, connection);
1212 }
1213 
1214 static aws_mqtt_transform_websocket_handshake_complete_fn s_websocket_handshake_transform_complete; /* fwd declare */
1215 
s_websocket_connect(struct aws_mqtt_client_connection * connection)1216 static int s_websocket_connect(struct aws_mqtt_client_connection *connection) {
1217     AWS_ASSERT(connection->websocket.enabled);
1218 
1219     /* These defaults were chosen because they're commmon in other MQTT libraries.
1220      * The user can modify the request in their transform callback if they need to. */
1221     const struct aws_byte_cursor default_path = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("/mqtt");
1222     const struct aws_http_header default_protocol_header = {
1223         .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("Sec-WebSocket-Protocol"),
1224         .value = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("mqtt"),
1225     };
1226 
1227     /* Build websocket handshake request */
1228     connection->websocket.handshake_request = aws_http_message_new_websocket_handshake_request(
1229         connection->allocator, default_path, aws_byte_cursor_from_string(connection->host_name));
1230 
1231     if (!connection->websocket.handshake_request) {
1232         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to generate websocket handshake request", (void *)connection);
1233         goto error;
1234     }
1235 
1236     if (aws_http_message_add_header(connection->websocket.handshake_request, default_protocol_header)) {
1237         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to generate websocket handshake request", (void *)connection);
1238         goto error;
1239     }
1240 
1241     /* If user registered a transform callback, call it and wait for transform_complete() to be called.
1242      * If no callback registered, call the transform_complete() function ourselves. */
1243     if (connection->websocket.handshake_transformer) {
1244         AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Transforming websocket handshake request.", (void *)connection);
1245 
1246         connection->websocket.handshake_transformer(
1247             connection->websocket.handshake_request,
1248             connection->websocket.handshake_transformer_ud,
1249             s_websocket_handshake_transform_complete,
1250             connection);
1251 
1252     } else {
1253         s_websocket_handshake_transform_complete(
1254             connection->websocket.handshake_request, AWS_ERROR_SUCCESS, connection);
1255     }
1256 
1257     return AWS_OP_SUCCESS;
1258 
1259 error:
1260     aws_http_message_release(connection->websocket.handshake_request);
1261     connection->websocket.handshake_request = NULL;
1262     return AWS_OP_ERR;
1263 }
1264 
s_websocket_handshake_transform_complete(struct aws_http_message * handshake_request,int error_code,void * complete_ctx)1265 static void s_websocket_handshake_transform_complete(
1266     struct aws_http_message *handshake_request,
1267     int error_code,
1268     void *complete_ctx) {
1269 
1270     struct aws_mqtt_client_connection *connection = complete_ctx;
1271 
1272     if (error_code) {
1273         AWS_LOGF_ERROR(
1274             AWS_LS_MQTT_CLIENT,
1275             "id=%p: Failure reported by websocket handshake transform callback.",
1276             (void *)connection);
1277 
1278         goto error;
1279     }
1280 
1281     if (connection->websocket.handshake_transformer) {
1282         AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Done transforming websocket handshake request.", (void *)connection);
1283     }
1284 
1285     /* Call websocket connect() */
1286     struct aws_websocket_client_connection_options websocket_options = {
1287         .allocator = connection->allocator,
1288         .bootstrap = connection->client->bootstrap,
1289         .socket_options = &connection->socket_options,
1290         .tls_options = connection->tls_options.ctx ? &connection->tls_options : NULL,
1291         .host = aws_byte_cursor_from_string(connection->host_name),
1292         .port = connection->port,
1293         .handshake_request = handshake_request,
1294         .initial_window_size = 0, /* Prevent websocket data from arriving before the MQTT handler is installed */
1295         .user_data = connection,
1296         .on_connection_setup = s_on_websocket_setup,
1297         .on_connection_shutdown = s_on_websocket_shutdown,
1298     };
1299 
1300     struct aws_http_proxy_options proxy_options;
1301     AWS_ZERO_STRUCT(proxy_options);
1302     if (connection->http_proxy_config != NULL) {
1303         aws_http_proxy_options_init_from_config(&proxy_options, connection->http_proxy_config);
1304         websocket_options.proxy_options = &proxy_options;
1305     }
1306 
1307     if (aws_websocket_client_connect(&websocket_options)) {
1308         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to initiate websocket connection.", (void *)connection);
1309         error_code = aws_last_error();
1310         goto error;
1311     }
1312 
1313     /* Success */
1314     return;
1315 
1316 error:
1317     /* Proceed to next step, telling it that we failed. */
1318     s_on_websocket_setup(NULL, error_code, -1, NULL, 0, connection);
1319 }
1320 
1321 #else  /* AWS_MQTT_WITH_WEBSOCKETS */
/*
 * Stub used when the library is built without AWS_MQTT_WITH_WEBSOCKETS: logs an error and always fails with
 * AWS_ERROR_MQTT_BUILT_WITHOUT_WEBSOCKETS. All arguments are ignored.
 */
int aws_mqtt_client_connection_use_websockets(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_transform_websocket_handshake_fn *transformer,
    void *transformer_ud,
    aws_mqtt_validate_websocket_handshake_fn *validator,
    void *validator_ud) {

    /* connection is only read by the log macro, which may be compiled out. */
    (void)connection;
    (void)transformer;
    (void)transformer_ud;
    (void)validator;
    (void)validator_ud;

    AWS_LOGF_ERROR(
        AWS_LS_MQTT_CLIENT,
        "id=%p: Cannot use websockets unless library is built with MQTT_WITH_WEBSOCKETS option.",
        (void *)connection);
    return aws_raise_error(AWS_ERROR_MQTT_BUILT_WITHOUT_WEBSOCKETS);
}
1342 
aws_mqtt_client_connection_set_websocket_proxy_options(struct aws_mqtt_client_connection * connection,struct aws_http_proxy_options * proxy_options)1343 int aws_mqtt_client_connection_set_websocket_proxy_options(
1344     struct aws_mqtt_client_connection *connection,
1345     struct aws_http_proxy_options *proxy_options) {
1346 
1347     (void)connection;
1348     (void)proxy_options;
1349 
1350     AWS_LOGF_ERROR(
1351         AWS_LS_MQTT_CLIENT,
1352         "id=%p: Cannot use websockets unless library is built with MQTT_WITH_WEBSOCKETS option.",
1353         (void *)connection);
1354 
1355     return aws_raise_error(AWS_ERROR_MQTT_BUILT_WITHOUT_WEBSOCKETS);
1356 }
1357 #endif /* AWS_MQTT_WITH_WEBSOCKETS */
1358 
1359 /*******************************************************************************
1360  * Connect
1361  ******************************************************************************/
1362 
aws_mqtt_client_connection_connect(struct aws_mqtt_client_connection * connection,const struct aws_mqtt_connection_options * connection_options)1363 int aws_mqtt_client_connection_connect(
1364     struct aws_mqtt_client_connection *connection,
1365     const struct aws_mqtt_connection_options *connection_options) {
1366 
1367     /* TODO: Do we need to support resuming the connection if user connect to the same connection & endpoint and the
1368      * clean_session is false?
1369      * If not, the broker will resume the connection in this case, and we pretend we are making a new connection, which
1370      * may cause some confusing behavior. This is basically what we have now. NOTE: The topic_tree is living with the
1371      * connection right now, which is really confusing.
1372      * If yes, an edge case will be: User disconnected from the connection with clean_session
1373      * being false, then connect to another endpoint with the same connection object, we probably need to clear all
1374      * those states from last connection and create a new "connection". Problem is what if user finish the second
1375      * connection and reconnect to the first endpoint. There is no way for us to resume the connection in this case. */
1376 
1377     AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Opening connection", (void *)connection);
1378     { /* BEGIN CRITICAL SECTION */
1379         mqtt_connection_lock_synced_data(connection);
1380 
1381         if (connection->synced_data.state != AWS_MQTT_CLIENT_STATE_DISCONNECTED) {
1382             mqtt_connection_unlock_synced_data(connection);
1383             return aws_raise_error(AWS_ERROR_MQTT_ALREADY_CONNECTED);
1384         }
1385         mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_CONNECTING);
1386         AWS_LOGF_DEBUG(
1387             AWS_LS_MQTT_CLIENT, "id=%p: Begin connecting process, switch state to CONNECTING.", (void *)connection);
1388         mqtt_connection_unlock_synced_data(connection);
1389     } /* END CRITICAL SECTION */
1390 
1391     if (connection->host_name) {
1392         aws_string_destroy(connection->host_name);
1393     }
1394 
1395     connection->host_name = aws_string_new_from_array(
1396         connection->allocator, connection_options->host_name.ptr, connection_options->host_name.len);
1397     connection->port = connection_options->port;
1398     connection->socket_options = *connection_options->socket_options;
1399     connection->clean_session = connection_options->clean_session;
1400     connection->keep_alive_time_secs = connection_options->keep_alive_time_secs;
1401     connection->connection_count = 0;
1402 
1403     if (!connection->keep_alive_time_secs) {
1404         connection->keep_alive_time_secs = s_default_keep_alive_sec;
1405     }
1406     if (!connection_options->protocol_operation_timeout_ms) {
1407         connection->operation_timeout_ns = UINT64_MAX;
1408     } else {
1409         connection->operation_timeout_ns = aws_timestamp_convert(
1410             (uint64_t)connection_options->protocol_operation_timeout_ms,
1411             AWS_TIMESTAMP_MILLIS,
1412             AWS_TIMESTAMP_NANOS,
1413             NULL);
1414     }
1415 
1416     if (!connection_options->ping_timeout_ms) {
1417         connection->ping_timeout_ns = s_default_ping_timeout_ns;
1418     } else {
1419         connection->ping_timeout_ns = aws_timestamp_convert(
1420             (uint64_t)connection_options->ping_timeout_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL);
1421     }
1422 
1423     /* Keep alive time should always be greater than the timeouts. */
1424     if (AWS_UNLIKELY(connection->keep_alive_time_secs * (uint64_t)AWS_TIMESTAMP_NANOS <= connection->ping_timeout_ns)) {
1425         AWS_LOGF_FATAL(
1426             AWS_LS_MQTT_CLIENT,
1427             "id=%p: Illegal configuration, Connection keep alive %" PRIu64
1428             "ns must be greater than the request timeouts %" PRIu64 "ns.",
1429             (void *)connection,
1430             (uint64_t)connection->keep_alive_time_secs * (uint64_t)AWS_TIMESTAMP_NANOS,
1431             connection->ping_timeout_ns);
1432         AWS_FATAL_ASSERT(
1433             connection->keep_alive_time_secs * (uint64_t)AWS_TIMESTAMP_NANOS > connection->ping_timeout_ns);
1434     }
1435 
1436     AWS_LOGF_INFO(
1437         AWS_LS_MQTT_CLIENT,
1438         "id=%p: using ping timeout of %" PRIu64 " ns",
1439         (void *)connection,
1440         connection->ping_timeout_ns);
1441 
1442     /* Cheat and set the tls_options host_name to our copy if they're the same */
1443     if (connection_options->tls_options) {
1444         connection->use_tls = true;
1445         if (aws_tls_connection_options_copy(&connection->tls_options, connection_options->tls_options)) {
1446 
1447             AWS_LOGF_ERROR(
1448                 AWS_LS_MQTT_CLIENT, "id=%p: Failed to copy TLS Connection Options into connection", (void *)connection);
1449             return AWS_OP_ERR;
1450         }
1451 
1452         if (!connection_options->tls_options->server_name) {
1453             struct aws_byte_cursor host_name_cur = aws_byte_cursor_from_string(connection->host_name);
1454             if (aws_tls_connection_options_set_server_name(
1455                     &connection->tls_options, connection->allocator, &host_name_cur)) {
1456 
1457                 AWS_LOGF_ERROR(
1458                     AWS_LS_MQTT_CLIENT, "id=%p: Failed to set TLS Connection Options server name", (void *)connection);
1459                 goto error;
1460             }
1461         }
1462 
1463     } else {
1464         AWS_ZERO_STRUCT(connection->tls_options);
1465     }
1466 
1467     /* Clean up old client_id */
1468     if (connection->client_id.buffer) {
1469         aws_byte_buf_clean_up(&connection->client_id);
1470     }
1471 
1472     /* Only set connection->client_id if a new one was provided */
1473     struct aws_byte_buf client_id_buf =
1474         aws_byte_buf_from_array(connection_options->client_id.ptr, connection_options->client_id.len);
1475     if (aws_byte_buf_init_copy(&connection->client_id, connection->allocator, &client_id_buf)) {
1476         AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Failed to copy client_id into connection", (void *)connection);
1477         goto error;
1478     }
1479 
1480     struct aws_linked_list cancelling_requests;
1481     aws_linked_list_init(&cancelling_requests);
1482     { /* BEGIN CRITICAL SECTION */
1483         mqtt_connection_lock_synced_data(connection);
1484         if (connection->clean_session) {
1485             AWS_LOGF_TRACE(
1486                 AWS_LS_MQTT_CLIENT,
1487                 "id=%p: a clean session connection requested, all the previous requests will fail",
1488                 (void *)connection);
1489             aws_linked_list_swap_contents(&connection->synced_data.pending_requests_list, &cancelling_requests);
1490         }
1491         mqtt_connection_unlock_synced_data(connection);
1492     } /* END CRITICAL SECTION */
1493 
1494     if (!aws_linked_list_empty(&cancelling_requests)) {
1495 
1496         struct aws_linked_list_node *current = aws_linked_list_front(&cancelling_requests);
1497         const struct aws_linked_list_node *end = aws_linked_list_end(&cancelling_requests);
1498         /* invoke all the complete callback for requests from previous session */
1499         while (current != end) {
1500             struct aws_mqtt_request *request = AWS_CONTAINER_OF(current, struct aws_mqtt_request, list_node);
1501             AWS_LOGF_TRACE(
1502                 AWS_LS_MQTT_CLIENT,
1503                 "id=%p: Establishing a new clean session connection, discard the previous request %" PRIu16,
1504                 (void *)connection,
1505                 request->packet_id);
1506             if (request->on_complete) {
1507                 request->on_complete(
1508                     connection,
1509                     request->packet_id,
1510                     AWS_ERROR_MQTT_CANCELLED_FOR_CLEAN_SESSION,
1511                     request->on_complete_ud);
1512             }
1513             current = current->next;
1514         }
1515         /* free the resource */
1516         { /* BEGIN CRITICAL SECTION */
1517             mqtt_connection_lock_synced_data(connection);
1518             while (!aws_linked_list_empty(&cancelling_requests)) {
1519                 struct aws_linked_list_node *node = aws_linked_list_pop_front(&cancelling_requests);
1520                 struct aws_mqtt_request *request = AWS_CONTAINER_OF(node, struct aws_mqtt_request, list_node);
1521                 aws_hash_table_remove(
1522                     &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
1523                 aws_memory_pool_release(&connection->synced_data.requests_pool, request);
1524             }
1525             mqtt_connection_unlock_synced_data(connection);
1526         } /* END CRITICAL SECTION */
1527     }
1528 
1529     /* Begin the connecting process, acquire the connection to keep it alive until we disconnected */
1530     aws_mqtt_client_connection_acquire(connection);
1531 
1532     if (s_mqtt_client_connect(connection, connection_options->on_connection_complete, connection_options->user_data)) {
1533         /*
1534          * An error calling s_mqtt_client_connect should (must) be mutually exclusive with s_mqtt_client_shutdown().
1535          * So it should be safe and correct to call release now to undo the pinning we did a few lines above.
1536          */
1537         aws_mqtt_client_connection_release(connection);
1538 
1539         /* client_id has been updated with something but it will get cleaned up when the connection gets cleaned up
1540          * so we don't need to worry about it here*/
1541         if (connection->clean_session) {
1542             AWS_LOGF_WARN(
1543                 AWS_LS_MQTT_CLIENT, "id=%p: The previous session has been cleaned up and losted!", (void *)connection);
1544         }
1545         goto error;
1546     }
1547 
1548     return AWS_OP_SUCCESS;
1549 
1550 error:
1551     aws_tls_connection_options_clean_up(&connection->tls_options);
1552     AWS_ZERO_STRUCT(connection->tls_options);
1553     { /* BEGIN CRITICAL SECTION */
1554         mqtt_connection_lock_synced_data(connection);
1555         mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_DISCONNECTED);
1556         mqtt_connection_unlock_synced_data(connection);
1557     } /* END CRITICAL SECTION */
1558     return AWS_OP_ERR;
1559 }
1560 
s_mqtt_client_connect(struct aws_mqtt_client_connection * connection,aws_mqtt_client_on_connection_complete_fn * on_connection_complete,void * userdata)1561 static int s_mqtt_client_connect(
1562     struct aws_mqtt_client_connection *connection,
1563     aws_mqtt_client_on_connection_complete_fn *on_connection_complete,
1564     void *userdata) {
1565     connection->on_connection_complete = on_connection_complete;
1566     connection->on_connection_complete_ud = userdata;
1567 
1568     int result = 0;
1569 #ifdef AWS_MQTT_WITH_WEBSOCKETS
1570     if (connection->websocket.enabled) {
1571         result = s_websocket_connect(connection);
1572     } else
1573 #endif /* AWS_MQTT_WITH_WEBSOCKETS */
1574     {
1575         struct aws_socket_channel_bootstrap_options channel_options;
1576         AWS_ZERO_STRUCT(channel_options);
1577         channel_options.bootstrap = connection->client->bootstrap;
1578         channel_options.host_name = aws_string_c_str(connection->host_name);
1579         channel_options.port = connection->port;
1580         channel_options.socket_options = &connection->socket_options;
1581         channel_options.tls_options = connection->use_tls ? &connection->tls_options : NULL;
1582         channel_options.setup_callback = &s_mqtt_client_init;
1583         channel_options.shutdown_callback = &s_mqtt_client_shutdown;
1584         channel_options.user_data = connection;
1585 
1586         if (connection->http_proxy_config == NULL) {
1587             result = aws_client_bootstrap_new_socket_channel(&channel_options);
1588         } else {
1589             struct aws_http_proxy_options proxy_options;
1590             AWS_ZERO_STRUCT(proxy_options);
1591 
1592             aws_http_proxy_options_init_from_config(&proxy_options, connection->http_proxy_config);
1593             result = aws_http_proxy_new_socket_channel(&channel_options, &proxy_options);
1594         }
1595     }
1596 
1597     if (result) {
1598         /* Connection attempt failed */
1599         AWS_LOGF_ERROR(
1600             AWS_LS_MQTT_CLIENT,
1601             "id=%p: Failed to begin connection routine, error %d (%s).",
1602             (void *)connection,
1603             aws_last_error(),
1604             aws_error_name(aws_last_error()));
1605         return AWS_OP_ERR;
1606     }
1607 
1608     return AWS_OP_SUCCESS;
1609 }
1610 
1611 /*******************************************************************************
1612  * Reconnect  DEPRECATED
1613  ******************************************************************************/
1614 
/*
 * Deprecated entry point. Reconnection now happens automatically, so this only logs and fails with
 * AWS_ERROR_UNSUPPORTED_OPERATION. All arguments are ignored.
 */
int aws_mqtt_client_connection_reconnect(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_client_on_connection_complete_fn *on_connection_complete,
    void *userdata) {

    (void)userdata;
    (void)on_connection_complete;
    (void)connection;

    AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "aws_mqtt_client_connection_reconnect has been DEPRECATED.");
    return aws_raise_error(AWS_ERROR_UNSUPPORTED_OPERATION);
}
1627 
1628 /*******************************************************************************
1629  * Disconnect
1630  ******************************************************************************/
1631 
/*
 * User-initiated disconnect.
 *
 * Only a CONNECTED or RECONNECTING connection may be closed. In that case, under the synced-data lock, the state
 * moves to DISCONNECTING and on_disconnect/userdata are stored; then mqtt_disconnect_impl() is invoked.
 * Returns AWS_OP_SUCCESS, or AWS_OP_ERR with AWS_ERROR_MQTT_NOT_CONNECTED raised for any other state.
 */
int aws_mqtt_client_connection_disconnect(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_client_on_disconnect_fn *on_disconnect,
    void *userdata) {

    AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: user called disconnect.", (void *)connection);

    bool accepted = false;
    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);

        accepted = connection->synced_data.state == AWS_MQTT_CLIENT_STATE_CONNECTED ||
                   connection->synced_data.state == AWS_MQTT_CLIENT_STATE_RECONNECTING;
        if (accepted) {
            mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_DISCONNECTING);
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "id=%p: User requests disconnecting, switch state to DISCONNECTING.",
                (void *)connection);
            connection->on_disconnect = on_disconnect;
            connection->on_disconnect_ud = userdata;
        }

        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */

    if (!accepted) {
        AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Connection is not open, and may not be closed", (void *)connection);
        return aws_raise_error(AWS_ERROR_MQTT_NOT_CONNECTED);
    }

    AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Closing connection", (void *)connection);
    mqtt_disconnect_impl(connection, AWS_OP_SUCCESS);
    return AWS_OP_SUCCESS;
}
1666 
1667 /*******************************************************************************
1668  * Subscribe
1669  ******************************************************************************/
1670 
s_on_publish_client_wrapper(const struct aws_byte_cursor * topic,const struct aws_byte_cursor * payload,bool dup,enum aws_mqtt_qos qos,bool retain,void * userdata)1671 static void s_on_publish_client_wrapper(
1672     const struct aws_byte_cursor *topic,
1673     const struct aws_byte_cursor *payload,
1674     bool dup,
1675     enum aws_mqtt_qos qos,
1676     bool retain,
1677     void *userdata) {
1678 
1679     struct subscribe_task_topic *task_topic = userdata;
1680 
1681     /* Call out to the user callback */
1682     if (task_topic->request.on_publish) {
1683         task_topic->request.on_publish(
1684             task_topic->connection, topic, payload, dup, qos, retain, task_topic->request.on_publish_ud);
1685     }
1686 }
1687 
s_task_topic_release(void * userdata)1688 static void s_task_topic_release(void *userdata) {
1689     struct subscribe_task_topic *task_topic = userdata;
1690     if (task_topic != NULL) {
1691         aws_ref_count_release(&task_topic->ref_count);
1692     }
1693 }
1694 
s_task_topic_clean_up(void * userdata)1695 static void s_task_topic_clean_up(void *userdata) {
1696 
1697     struct subscribe_task_topic *task_topic = userdata;
1698 
1699     if (task_topic->request.on_cleanup) {
1700         task_topic->request.on_cleanup(task_topic->request.on_publish_ud);
1701     }
1702     aws_string_destroy(task_topic->filter);
1703     aws_mem_release(task_topic->connection->allocator, task_topic);
1704 }
1705 
/*
 * Send callback for a subscribe request (registered via mqtt_create_request by both subscribe entry points).
 *
 * Builds the SUBSCRIBE packet the first time through, registers each topic in the connection's subscription
 * topic tree (once, inside a transaction), then encodes the packet and writes it to the channel slot.
 *
 * packet_id:        packet id assigned to this request.
 * is_first_attempt: only used for logging ("first attempt" vs "resend").
 * userdata:         the subscribe_task_arg for this request.
 *
 * Returns AWS_MQTT_CLIENT_REQUEST_ONGOING once the packet has been handed to (or failed to be written to) the
 * channel. Returns AWS_MQTT_CLIENT_REQUEST_ERROR if the packet could not be built or encoded, or if the
 * topic list is empty.
 */
static enum aws_mqtt_client_request_state s_subscribe_send(uint16_t packet_id, bool is_first_attempt, void *userdata) {

    (void)is_first_attempt;

    struct subscribe_task_arg *task_arg = userdata;
    /* task_arg is zero-initialized by its creators, so packet_type == 0 means the packet has not been built yet. */
    bool initing_packet = task_arg->subscribe.fixed_header.packet_type == 0;
    struct aws_io_message *message = NULL;

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT,
        "id=%p: Attempting send of subscribe %" PRIu16 " (%s)",
        (void *)task_arg->connection,
        packet_id,
        is_first_attempt ? "first attempt" : "resend");

    if (initing_packet) {
        /* Init the subscribe packet */
        if (aws_mqtt_packet_subscribe_init(&task_arg->subscribe, task_arg->connection->allocator, packet_id)) {
            return AWS_MQTT_CLIENT_REQUEST_ERROR;
        }
    }

    /* num_topics is a size_t, so this only rejects an empty topic list. */
    const size_t num_topics = aws_array_list_length(&task_arg->topics);
    if (num_topics <= 0) {
        aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
        return AWS_MQTT_CLIENT_REQUEST_ERROR;
    }

    /* Stack storage for one topic-tree action per topic; must stay non-zero-sized, hence the check above. */
    AWS_VARIABLE_LENGTH_ARRAY(uint8_t, transaction_buf, num_topics * aws_mqtt_topic_tree_action_size);
    struct aws_array_list transaction;
    aws_array_list_init_static(&transaction, transaction_buf, num_topics, aws_mqtt_topic_tree_action_size);

    for (size_t i = 0; i < num_topics; ++i) {

        struct subscribe_task_topic *topic = NULL;
        aws_array_list_get_at(&task_arg->topics, &topic, i);
        AWS_ASSUME(topic); /* We know we're within bounds */

        /* Topics are added to the packet only while it is being built for the first time. */
        if (initing_packet) {
            if (aws_mqtt_packet_subscribe_add_topic(&task_arg->subscribe, topic->request.topic, topic->request.qos)) {
                goto handle_error;
            }
        }

        /* The topic tree is updated only once per request; resends skip this. */
        if (!task_arg->tree_updated) {
            if (aws_mqtt_topic_tree_transaction_insert(
                    &task_arg->connection->thread_data.subscriptions,
                    &transaction,
                    topic->filter,
                    topic->request.qos,
                    s_on_publish_client_wrapper,
                    s_task_topic_release,
                    topic)) {

                /* NOTE(review): refs acquired for earlier topics in this loop are assumed to be released by the
                 * roll back below (via s_task_topic_release) — confirm against topic_tree. */
                goto handle_error;
            }
            /* If insert succeed, acquire the refcount */
            aws_ref_count_acquire(&topic->ref_count);
        }
    }

    message = mqtt_get_message_for_packet(task_arg->connection, &task_arg->subscribe.fixed_header);
    if (!message) {

        goto handle_error;
    }

    if (aws_mqtt_packet_subscribe_encode(&message->message_data, &task_arg->subscribe)) {

        goto handle_error;
    }

    /* This is not necessarily a fatal error; if the subscribe fails, it'll just retry. Still need to clean up though.
     */
    if (aws_channel_slot_send_message(task_arg->connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
        aws_mem_release(message->allocator, message);
    }

    /* Make the topic-tree insertions permanent the first time the packet goes out. */
    if (!task_arg->tree_updated) {
        aws_mqtt_topic_tree_transaction_commit(&task_arg->connection->thread_data.subscriptions, &transaction);
        task_arg->tree_updated = true;
    }

    aws_array_list_clean_up(&transaction);
    return AWS_MQTT_CLIENT_REQUEST_ONGOING;

handle_error:

    /* Free the unsent message (if any) and undo any uncommitted topic-tree insertions. */
    if (message) {
        aws_mem_release(message->allocator, message);
    }
    if (!task_arg->tree_updated) {
        aws_mqtt_topic_tree_transaction_roll_back(&task_arg->connection->thread_data.subscriptions, &transaction);
    }

    aws_array_list_clean_up(&transaction);
    return AWS_MQTT_CLIENT_REQUEST_ERROR;
}
1804 
/*
 * Completion callback for a multi-topic subscribe request (registered by
 * aws_mqtt_client_connection_subscribe_multiple).
 *
 * Reports the outcome to the user. If on_suback.multi is set, it receives a temporary list of
 * aws_mqtt_topic_subscription pointers, one per topic. Otherwise, if on_suback.single is set, it receives
 * the first topic only. Afterwards the request's reference on every topic is dropped, and the topic list,
 * the SUBSCRIBE packet and task_arg itself are freed.
 *
 * connection: connection the request belonged to.
 * packet_id:  packet id of the subscribe request.
 * error_code: result forwarded to the user callback.
 * userdata:   the subscribe_task_arg for this request.
 */
static void s_subscribe_complete(
    struct aws_mqtt_client_connection *connection,
    uint16_t packet_id,
    int error_code,
    void *userdata) {

    struct subscribe_task_arg *task_arg = userdata;

    /* Assumes at least one topic: element 0 is what the single-topic callback branch below reports. */
    struct subscribe_task_topic *topic = NULL;
    aws_array_list_get_at(&task_arg->topics, &topic, 0);
    AWS_ASSUME(topic);

    AWS_LOGF_DEBUG(
        AWS_LS_MQTT_CLIENT,
        "id=%p: Subscribe %" PRIu16 " completed with error_code %d",
        (void *)connection,
        packet_id,
        error_code);

    size_t list_len = aws_array_list_length(&task_arg->topics);
    if (task_arg->on_suback.multi) {
        /* create a list of aws_mqtt_topic_subscription pointers from topics for the callback */
        /* The list lives on the stack and points into the topics, so it is only valid during the callback. */
        AWS_VARIABLE_LENGTH_ARRAY(uint8_t, cb_list_buf, list_len * sizeof(void *));
        struct aws_array_list cb_list;
        aws_array_list_init_static(&cb_list, cb_list_buf, list_len, sizeof(void *));
        int err = 0;
        for (size_t i = 0; i < list_len; i++) {
            err |= aws_array_list_get_at(&task_arg->topics, &topic, i);
            struct aws_mqtt_topic_subscription *subscription = &topic->request;
            err |= aws_array_list_push_back(&cb_list, &subscription);
        }
        AWS_ASSUME(!err);
        task_arg->on_suback.multi(connection, packet_id, &cb_list, error_code, task_arg->on_suback_ud);
        aws_array_list_clean_up(&cb_list);
    } else if (task_arg->on_suback.single) {
        /* topic still refers to element 0 here. NOTE(review): on_suback may be a union of multi/single;
         * confirm in client_impl.h. */
        task_arg->on_suback.single(
            connection, packet_id, &topic->request.topic, topic->request.qos, error_code, task_arg->on_suback_ud);
    }
    /* Drop the request's reference on each topic; the topic tree may still hold its own. */
    for (size_t i = 0; i < list_len; i++) {
        aws_array_list_get_at(&task_arg->topics, &topic, i);
        s_task_topic_release(topic);
    }
    aws_array_list_clean_up(&task_arg->topics);
    aws_mqtt_packet_subscribe_clean_up(&task_arg->subscribe);
    aws_mem_release(task_arg->connection->allocator, task_arg);
}
1851 
/*
 * Subscribes to every aws_mqtt_topic_subscription in topic_filters with a single SUBSCRIBE packet.
 *
 * Each entry's topic filter is validated and copied into an owned, ref-counted subscribe_task_topic, so the
 * caller's list and cursors need not outlive this call. on_suback is invoked once with all subscriptions when
 * the request completes.
 *
 * connection:    connection to subscribe on.
 * topic_filters: non-NULL, non-empty list of struct aws_mqtt_topic_subscription (stored by value).
 * on_suback:     optional completion callback; on_suback_ud is passed through to it.
 *
 * Returns the packet id of the subscribe request, or 0 with an error raised:
 *  - AWS_ERROR_INVALID_ARGUMENT when topic_filters is NULL;
 *  - AWS_ERROR_MQTT_INVALID_TOPIC when the list is empty or a filter is invalid;
 *  - otherwise whatever allocation or request creation raised.
 */
uint16_t aws_mqtt_client_connection_subscribe_multiple(
    struct aws_mqtt_client_connection *connection,
    const struct aws_array_list *topic_filters,
    aws_mqtt_suback_multi_fn *on_suback,
    void *on_suback_ud) {

    AWS_PRECONDITION(connection);

    if (topic_filters == NULL) {
        aws_raise_error(AWS_ERROR_INVALID_ARGUMENT);
        return 0;
    }

    const size_t num_topics = aws_array_list_length(topic_filters);
    if (num_topics == 0) {
        /* s_subscribe_send rejects an empty topic list with this same error, and s_subscribe_complete reads
         * element 0 of the list, so fail fast instead of creating a request that can only error out. */
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT, "id=%p: Cannot subscribe to an empty list of topic filters", (void *)connection);
        aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
        return 0;
    }

    struct subscribe_task_arg *task_arg = aws_mem_calloc(connection->allocator, 1, sizeof(struct subscribe_task_arg));
    if (!task_arg) {
        return 0;
    }

    task_arg->connection = connection;
    task_arg->on_suback.multi = on_suback;
    task_arg->on_suback_ud = on_suback_ud;

    /* The list stores subscribe_task_topic pointers. */
    if (aws_array_list_init_dynamic(&task_arg->topics, connection->allocator, num_topics, sizeof(void *))) {
        goto handle_error;
    }

    AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Starting multi-topic subscribe", (void *)connection);

    for (size_t i = 0; i < num_topics; ++i) {

        struct aws_mqtt_topic_subscription *request = NULL;
        aws_array_list_get_at_ptr(topic_filters, (void **)&request, i);

        if (!aws_mqtt_is_valid_topic_filter(&request->topic)) {
            aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
            goto handle_error;
        }

        struct subscribe_task_topic *task_topic =
            aws_mem_calloc(connection->allocator, 1, sizeof(struct subscribe_task_topic));
        if (!task_topic) {
            goto handle_error;
        }
        aws_ref_count_init(&task_topic->ref_count, task_topic, (aws_simple_completion_callback *)s_task_topic_clean_up);

        task_topic->connection = connection;
        task_topic->request = *request;

        task_topic->filter = aws_string_new_from_array(
            connection->allocator, task_topic->request.topic.ptr, task_topic->request.topic.len);
        if (!task_topic->filter) {
            aws_mem_release(connection->allocator, task_topic);
            goto handle_error;
        }

        /* Update request topic cursor to refer to owned string */
        task_topic->request.topic = aws_byte_cursor_from_string(task_topic->filter);

        AWS_LOGF_DEBUG(
            AWS_LS_MQTT_CLIENT,
            "id=%p:     Adding topic \"" PRInSTR "\"",
            (void *)connection,
            AWS_BYTE_CURSOR_PRI(task_topic->request.topic));

        /* Push into the list; if that fails the topic is not yet tracked, so free it here to avoid a leak. */
        if (aws_array_list_push_back(&task_arg->topics, &task_topic)) {
            aws_string_destroy(task_topic->filter);
            aws_mem_release(connection->allocator, task_topic);
            goto handle_error;
        }
    }

    uint16_t packet_id = mqtt_create_request(
        task_arg->connection, &s_subscribe_send, task_arg, &s_subscribe_complete, task_arg, false /* noRetry */);

    if (packet_id == 0) {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Failed to kick off multi-topic subscribe, with error %s",
            (void *)connection,
            aws_error_debug_str(aws_last_error()));
        goto handle_error;
    }

    AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Sending multi-topic subscribe %" PRIu16, (void *)connection, packet_id);
    return packet_id;

handle_error:

    /* Nothing has been registered with the request machinery or topic tree yet, so free every topic directly
     * (bypassing the ref count and the user's on_cleanup hook). */
    if (task_arg) {

        if (task_arg->topics.data) {

            const size_t num_added_topics = aws_array_list_length(&task_arg->topics);
            for (size_t i = 0; i < num_added_topics; ++i) {

                struct subscribe_task_topic *task_topic = NULL;
                aws_array_list_get_at(&task_arg->topics, (void **)&task_topic, i);
                AWS_ASSUME(task_topic);

                aws_string_destroy(task_topic->filter);
                aws_mem_release(connection->allocator, task_topic);
            }

            aws_array_list_clean_up(&task_arg->topics);
        }

        aws_mem_release(connection->allocator, task_arg);
    }
    return 0;
}
1956 
1957 /*******************************************************************************
1958  * Subscribe Single
1959  ******************************************************************************/
1960 
s_subscribe_single_complete(struct aws_mqtt_client_connection * connection,uint16_t packet_id,int error_code,void * userdata)1961 static void s_subscribe_single_complete(
1962     struct aws_mqtt_client_connection *connection,
1963     uint16_t packet_id,
1964     int error_code,
1965     void *userdata) {
1966 
1967     struct subscribe_task_arg *task_arg = userdata;
1968 
1969     AWS_LOGF_DEBUG(
1970         AWS_LS_MQTT_CLIENT,
1971         "id=%p: Subscribe %" PRIu16 " completed with error code %d",
1972         (void *)connection,
1973         packet_id,
1974         error_code);
1975 
1976     AWS_ASSERT(aws_array_list_length(&task_arg->topics) == 1);
1977     struct subscribe_task_topic *topic = NULL;
1978     aws_array_list_get_at(&task_arg->topics, &topic, 0);
1979     AWS_ASSUME(topic); /* There needs to be exactly 1 topic in this list */
1980     if (task_arg->on_suback.single) {
1981         AWS_ASSUME(aws_string_is_valid(topic->filter));
1982         aws_mqtt_suback_fn *suback = task_arg->on_suback.single;
1983         suback(connection, packet_id, &topic->request.topic, topic->request.qos, error_code, task_arg->on_suback_ud);
1984     }
1985     s_task_topic_release(topic);
1986     aws_array_list_clean_up(&task_arg->topics);
1987     aws_mqtt_packet_subscribe_clean_up(&task_arg->subscribe);
1988     aws_mem_release(task_arg->connection->allocator, task_arg);
1989 }
1990 
/*
 * Subscribes to a single topic filter.
 *
 * The filter is validated and copied into an owned, ref-counted subscribe_task_topic that also records the
 * user's on_publish handler, its userdata and its cleanup hook. on_suback is called once when the request
 * completes.
 *
 * Returns the packet id of the subscribe request, or 0 with an error raised
 * (AWS_ERROR_MQTT_INVALID_TOPIC for an invalid filter).
 */
uint16_t aws_mqtt_client_connection_subscribe(
    struct aws_mqtt_client_connection *connection,
    const struct aws_byte_cursor *topic_filter,
    enum aws_mqtt_qos qos,
    aws_mqtt_client_publish_received_fn *on_publish,
    void *on_publish_ud,
    aws_mqtt_userdata_cleanup_fn *on_ud_cleanup,
    aws_mqtt_suback_fn *on_suback,
    void *on_suback_ud) {

    AWS_PRECONDITION(connection);

    if (!aws_mqtt_is_valid_topic_filter(topic_filter)) {
        aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
        return 0;
    }

    struct aws_allocator *allocator = connection->allocator;
    struct subscribe_task_arg *task_arg = NULL;
    struct subscribe_task_topic *task_topic = NULL;
    void *topic_slot = NULL;

    /* Exactly one topic, so the task argument and the single pointer slot backing its topic list share one
     * allocation. */
    task_arg = aws_mem_acquire_many(
        allocator, 2, &task_arg, sizeof(struct subscribe_task_arg), &topic_slot, sizeof(struct subscribe_task_topic *));
    if (task_arg == NULL) {
        goto on_failure;
    }

    AWS_ZERO_STRUCT(*task_arg);
    task_arg->connection = connection;
    task_arg->on_suback.single = on_suback;
    task_arg->on_suback_ud = on_suback_ud;

    /* The list holds a pointer to the topic, not the topic itself. */
    aws_array_list_init_static(&task_arg->topics, topic_slot, 1, sizeof(void *));

    task_topic = aws_mem_calloc(allocator, 1, sizeof(struct subscribe_task_topic));
    if (task_topic == NULL) {
        goto on_failure;
    }
    aws_ref_count_init(&task_topic->ref_count, task_topic, (aws_simple_completion_callback *)s_task_topic_clean_up);
    aws_array_list_push_back(&task_arg->topics, &task_topic);

    task_topic->filter = aws_string_new_from_array(allocator, topic_filter->ptr, topic_filter->len);
    if (task_topic->filter == NULL) {
        goto on_failure;
    }

    task_topic->connection = connection;
    task_topic->request.topic = aws_byte_cursor_from_string(task_topic->filter);
    task_topic->request.qos = qos;
    task_topic->request.on_publish = on_publish;
    task_topic->request.on_cleanup = on_ud_cleanup;
    task_topic->request.on_publish_ud = on_publish_ud;

    const uint16_t packet_id = mqtt_create_request(
        task_arg->connection, &s_subscribe_send, task_arg, &s_subscribe_single_complete, task_arg, false /* noRetry */);
    if (packet_id == 0) {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Failed to start subscribe on topic " PRInSTR " with error %s",
            (void *)connection,
            AWS_BYTE_CURSOR_PRI(task_topic->request.topic),
            aws_error_debug_str(aws_last_error()));
        goto on_failure;
    }

    AWS_LOGF_DEBUG(
        AWS_LS_MQTT_CLIENT,
        "id=%p: Starting subscribe %" PRIu16 " on topic " PRInSTR,
        (void *)connection,
        packet_id,
        AWS_BYTE_CURSOR_PRI(task_topic->request.topic));

    return packet_id;

on_failure:
    /* Nothing was registered yet: free the topic directly (bypassing its ref count) and then the shared block. */
    if (task_topic != NULL) {
        if (task_topic->filter != NULL) {
            aws_string_destroy(task_topic->filter);
        }
        aws_mem_release(allocator, task_topic);
    }

    if (task_arg != NULL) {
        aws_mem_release(allocator, task_arg);
    }

    return 0;
}
2089 
2090 /*******************************************************************************
2091  * Subscribe Local
2092  ******************************************************************************/
2093 
/*
 * Per-request state for a local-only subscription (aws_mqtt_client_connection_subscribe_local).
 * The lifetime of this struct is from subscribe -> suback: it is allocated by the subscribe_local entry point and
 * freed in s_subscribe_local_complete (or on the entry point's error path).
 */
struct subscribe_local_task_arg {

    /* Connection the local subscription is registered on. */
    struct aws_mqtt_client_connection *connection;

    /* Ref-counted topic record. s_subscribe_local_send inserts it into the topic tree and takes an extra reference;
     * s_subscribe_local_complete drops the reference held by this task arg. */
    struct subscribe_task_topic *task_topic;

    /* Optional completion callback and its user data, invoked from s_subscribe_local_complete. */
    aws_mqtt_suback_fn *on_suback;
    void *on_suback_ud;
};
2104 
s_subscribe_local_send(uint16_t packet_id,bool is_first_attempt,void * userdata)2105 static enum aws_mqtt_client_request_state s_subscribe_local_send(
2106     uint16_t packet_id,
2107     bool is_first_attempt,
2108     void *userdata) {
2109 
2110     (void)is_first_attempt;
2111 
2112     struct subscribe_local_task_arg *task_arg = userdata;
2113 
2114     AWS_LOGF_TRACE(
2115         AWS_LS_MQTT_CLIENT,
2116         "id=%p: Attempting save of local subscribe %" PRIu16 " (%s)",
2117         (void *)task_arg->connection,
2118         packet_id,
2119         is_first_attempt ? "first attempt" : "redo");
2120 
2121     struct subscribe_task_topic *topic = task_arg->task_topic;
2122     if (aws_mqtt_topic_tree_insert(
2123             &task_arg->connection->thread_data.subscriptions,
2124             topic->filter,
2125             topic->request.qos,
2126             s_on_publish_client_wrapper,
2127             s_task_topic_release,
2128             topic)) {
2129 
2130         return AWS_MQTT_CLIENT_REQUEST_ERROR;
2131     }
2132     aws_ref_count_acquire(&topic->ref_count);
2133 
2134     return AWS_MQTT_CLIENT_REQUEST_COMPLETE;
2135 }
2136 
s_subscribe_local_complete(struct aws_mqtt_client_connection * connection,uint16_t packet_id,int error_code,void * userdata)2137 static void s_subscribe_local_complete(
2138     struct aws_mqtt_client_connection *connection,
2139     uint16_t packet_id,
2140     int error_code,
2141     void *userdata) {
2142 
2143     struct subscribe_local_task_arg *task_arg = userdata;
2144 
2145     AWS_LOGF_DEBUG(
2146         AWS_LS_MQTT_CLIENT,
2147         "id=%p: Local subscribe %" PRIu16 " completed with error code %d",
2148         (void *)connection,
2149         packet_id,
2150         error_code);
2151 
2152     struct subscribe_task_topic *topic = task_arg->task_topic;
2153     if (task_arg->on_suback) {
2154         aws_mqtt_suback_fn *suback = task_arg->on_suback;
2155         suback(connection, packet_id, &topic->request.topic, topic->request.qos, error_code, task_arg->on_suback_ud);
2156     }
2157     s_task_topic_release(topic);
2158 
2159     aws_mem_release(task_arg->connection->allocator, task_arg);
2160 }
2161 
/*
 * Registers a local-only subscription for `topic_filter`: incoming publishes matching the filter are routed to
 * `on_publish`, but no SUBSCRIBE packet is sent to the server.
 *
 * The filter is copied, so the caller's cursor only needs to be valid for the duration of this call.
 * `on_ud_cleanup` (optional) is stored with the subscription and receives `on_publish_ud`. `on_suback` (optional)
 * is invoked with `on_suback_ud` when the request completes.
 *
 * Returns the request's packet id, or 0 on failure.
 */
uint16_t aws_mqtt_client_connection_subscribe_local(
    struct aws_mqtt_client_connection *connection,
    const struct aws_byte_cursor *topic_filter,
    aws_mqtt_client_publish_received_fn *on_publish,
    void *on_publish_ud,
    aws_mqtt_userdata_cleanup_fn *on_ud_cleanup,
    aws_mqtt_suback_fn *on_suback,
    void *on_suback_ud) {

    AWS_PRECONDITION(connection);

    if (!aws_mqtt_is_valid_topic_filter(topic_filter)) {
        aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
        return 0;
    }

    struct subscribe_task_topic *task_topic = NULL;

    /* aws_mem_calloc already returns zeroed memory; no separate AWS_ZERO_STRUCT is needed. */
    struct subscribe_local_task_arg *task_arg =
        aws_mem_calloc(connection->allocator, 1, sizeof(struct subscribe_local_task_arg));
    if (!task_arg) {
        goto handle_error;
    }

    task_arg->connection = connection;
    task_arg->on_suback = on_suback;
    task_arg->on_suback_ud = on_suback_ud;

    task_topic = aws_mem_calloc(connection->allocator, 1, sizeof(struct subscribe_task_topic));
    if (!task_topic) {
        goto handle_error;
    }
    aws_ref_count_init(&task_topic->ref_count, task_topic, (aws_simple_completion_callback *)s_task_topic_clean_up);
    task_arg->task_topic = task_topic;

    task_topic->filter = aws_string_new_from_array(connection->allocator, topic_filter->ptr, topic_filter->len);
    if (!task_topic->filter) {
        goto handle_error;
    }

    task_topic->connection = connection;
    task_topic->is_local = true;
    task_topic->request.topic = aws_byte_cursor_from_string(task_topic->filter);
    task_topic->request.on_publish = on_publish;
    task_topic->request.on_cleanup = on_ud_cleanup;
    task_topic->request.on_publish_ud = on_publish_ud;

    uint16_t packet_id = mqtt_create_request(
        task_arg->connection,
        s_subscribe_local_send,
        task_arg,
        &s_subscribe_local_complete,
        task_arg,
        false /* noRetry */);

    if (packet_id == 0) {
        /* The request was not created, so task_topic is still exclusively ours and safe to read. */
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Failed to start local subscribe on topic " PRInSTR " with error %s",
            (void *)connection,
            AWS_BYTE_CURSOR_PRI(task_topic->request.topic),
            aws_error_debug_str(aws_last_error()));
        goto handle_error;
    }

    /*
     * From here on task_arg and task_topic belong to the request: s_subscribe_local_complete frees task_arg and
     * drops a topic reference, possibly before this line runs (NOTE(review): depends on which thread the request
     * executes on). Log using the caller's cursor, which has identical content, instead of request-owned memory.
     */
    AWS_LOGF_DEBUG(
        AWS_LS_MQTT_CLIENT,
        "id=%p: Starting local subscribe %" PRIu16 " on topic " PRInSTR,
        (void *)connection,
        packet_id,
        AWS_BYTE_CURSOR_PRI(*topic_filter));
    return packet_id;

handle_error:

    /* Release directly rather than through the ref count so the user's on_cleanup is not invoked on failure. */
    if (task_topic) {
        if (task_topic->filter) {
            aws_string_destroy(task_topic->filter);
        }
        aws_mem_release(connection->allocator, task_topic);
    }

    if (task_arg) {
        aws_mem_release(connection->allocator, task_arg);
    }

    return 0;
}
2251 
2252 /*******************************************************************************
2253  * Resubscribe
2254  ******************************************************************************/
2255 
s_reconnect_resub_iterator(const struct aws_byte_cursor * topic,enum aws_mqtt_qos qos,void * user_data)2256 static bool s_reconnect_resub_iterator(const struct aws_byte_cursor *topic, enum aws_mqtt_qos qos, void *user_data) {
2257     struct subscribe_task_arg *task_arg = user_data;
2258 
2259     struct subscribe_task_topic *task_topic =
2260         aws_mem_calloc(task_arg->connection->allocator, 1, sizeof(struct subscribe_task_topic));
2261     struct aws_mqtt_topic_subscription sub;
2262     AWS_ZERO_STRUCT(sub);
2263     sub.topic = *topic;
2264     sub.qos = qos;
2265     task_topic->request = sub;
2266     task_topic->connection = task_arg->connection;
2267 
2268     aws_array_list_push_back(&task_arg->topics, &task_topic);
2269     aws_ref_count_init(&task_topic->ref_count, task_topic, (aws_simple_completion_callback *)s_task_topic_clean_up);
2270     return true;
2271 }
2272 
s_resubscribe_send(uint16_t packet_id,bool is_first_attempt,void * userdata)2273 static enum aws_mqtt_client_request_state s_resubscribe_send(
2274     uint16_t packet_id,
2275     bool is_first_attempt,
2276     void *userdata) {
2277 
2278     struct subscribe_task_arg *task_arg = userdata;
2279     bool initing_packet = task_arg->subscribe.fixed_header.packet_type == 0;
2280     struct aws_io_message *message = NULL;
2281 
2282     size_t sub_count = aws_mqtt_topic_tree_get_sub_count(&task_arg->connection->thread_data.subscriptions);
2283     if (sub_count == 0) {
2284         AWS_LOGF_TRACE(
2285             AWS_LS_MQTT_CLIENT,
2286             "id=%p: Not subscribed to any topics. Resubscribe is unnecessary, no packet will be sent.",
2287             (void *)task_arg->connection);
2288         return AWS_MQTT_CLIENT_REQUEST_COMPLETE;
2289     }
2290     if (aws_array_list_init_dynamic(&task_arg->topics, task_arg->connection->allocator, sub_count, sizeof(void *))) {
2291         goto handle_error;
2292     }
2293     aws_mqtt_topic_tree_iterate(&task_arg->connection->thread_data.subscriptions, s_reconnect_resub_iterator, task_arg);
2294 
2295     AWS_LOGF_TRACE(
2296         AWS_LS_MQTT_CLIENT,
2297         "id=%p: Attempting send of resubscribe %" PRIu16 " (%s)",
2298         (void *)task_arg->connection,
2299         packet_id,
2300         is_first_attempt ? "first attempt" : "resend");
2301 
2302     if (initing_packet) {
2303         /* Init the subscribe packet */
2304         if (aws_mqtt_packet_subscribe_init(&task_arg->subscribe, task_arg->connection->allocator, packet_id)) {
2305             return AWS_MQTT_CLIENT_REQUEST_ERROR;
2306         }
2307 
2308         const size_t num_topics = aws_array_list_length(&task_arg->topics);
2309         if (num_topics <= 0) {
2310             aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
2311             return AWS_MQTT_CLIENT_REQUEST_ERROR;
2312         }
2313 
2314         for (size_t i = 0; i < num_topics; ++i) {
2315 
2316             struct subscribe_task_topic *topic = NULL;
2317             aws_array_list_get_at(&task_arg->topics, &topic, i);
2318             AWS_ASSUME(topic); /* We know we're within bounds */
2319 
2320             if (aws_mqtt_packet_subscribe_add_topic(&task_arg->subscribe, topic->request.topic, topic->request.qos)) {
2321                 goto handle_error;
2322             }
2323         }
2324     }
2325 
2326     message = mqtt_get_message_for_packet(task_arg->connection, &task_arg->subscribe.fixed_header);
2327     if (!message) {
2328 
2329         goto handle_error;
2330     }
2331 
2332     if (aws_mqtt_packet_subscribe_encode(&message->message_data, &task_arg->subscribe)) {
2333 
2334         goto handle_error;
2335     }
2336 
2337     /* This is not necessarily a fatal error; if the send fails, it'll just retry.  Still need to clean up though. */
2338     if (aws_channel_slot_send_message(task_arg->connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
2339         aws_mem_release(message->allocator, message);
2340     }
2341 
2342     return AWS_MQTT_CLIENT_REQUEST_ONGOING;
2343 
2344 handle_error:
2345 
2346     if (message) {
2347         aws_mem_release(message->allocator, message);
2348     }
2349 
2350     return AWS_MQTT_CLIENT_REQUEST_ERROR;
2351 }
2352 
s_resubscribe_complete(struct aws_mqtt_client_connection * connection,uint16_t packet_id,int error_code,void * userdata)2353 static void s_resubscribe_complete(
2354     struct aws_mqtt_client_connection *connection,
2355     uint16_t packet_id,
2356     int error_code,
2357     void *userdata) {
2358 
2359     struct subscribe_task_arg *task_arg = userdata;
2360 
2361     struct subscribe_task_topic *topic = NULL;
2362     aws_array_list_get_at(&task_arg->topics, &topic, 0);
2363     AWS_ASSUME(topic);
2364 
2365     AWS_LOGF_DEBUG(
2366         AWS_LS_MQTT_CLIENT,
2367         "id=%p: Subscribe %" PRIu16 " completed with error_code %d",
2368         (void *)connection,
2369         packet_id,
2370         error_code);
2371 
2372     size_t list_len = aws_array_list_length(&task_arg->topics);
2373     if (task_arg->on_suback.multi) {
2374         /* create a list of aws_mqtt_topic_subscription pointers from topics for the callback */
2375         AWS_VARIABLE_LENGTH_ARRAY(uint8_t, cb_list_buf, list_len * sizeof(void *));
2376         struct aws_array_list cb_list;
2377         aws_array_list_init_static(&cb_list, cb_list_buf, list_len, sizeof(void *));
2378         int err = 0;
2379         for (size_t i = 0; i < list_len; i++) {
2380             err |= aws_array_list_get_at(&task_arg->topics, &topic, i);
2381             struct aws_mqtt_topic_subscription *subscription = &topic->request;
2382             err |= aws_array_list_push_back(&cb_list, &subscription);
2383         }
2384         AWS_ASSUME(!err);
2385         task_arg->on_suback.multi(connection, packet_id, &cb_list, error_code, task_arg->on_suback_ud);
2386         aws_array_list_clean_up(&cb_list);
2387     } else if (task_arg->on_suback.single) {
2388         task_arg->on_suback.single(
2389             connection, packet_id, &topic->request.topic, topic->request.qos, error_code, task_arg->on_suback_ud);
2390     }
2391 
2392     /* We need to cleanup the subscribe_task_topics, since they are not inserted into the topic tree by resubscribe. We
2393      * take the ownership to clean it up */
2394     for (size_t i = 0; i < list_len; i++) {
2395         aws_array_list_get_at(&task_arg->topics, &topic, i);
2396         s_task_topic_release(topic);
2397     }
2398     aws_array_list_clean_up(&task_arg->topics);
2399     aws_mqtt_packet_subscribe_clean_up(&task_arg->subscribe);
2400     aws_mem_release(task_arg->connection->allocator, task_arg);
2401 }
2402 
/*
 * Queues a single SUBSCRIBE request that re-sends every subscription currently held in the connection's topic
 * tree. The topic list is gathered lazily by s_resubscribe_send; `on_suback` (optional) receives the results via
 * s_resubscribe_complete together with `on_suback_ud`.
 *
 * Returns the packet id of the queued request, or 0 if it could not be created.
 */
uint16_t aws_mqtt_resubscribe_existing_topics(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_suback_multi_fn *on_suback,
    void *on_suback_ud) {

    struct subscribe_task_arg *resub_arg = aws_mem_acquire(connection->allocator, sizeof(struct subscribe_task_arg));
    if (resub_arg == NULL) {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT, "id=%p: failed to allocate storage for resubscribe arguments", (void *)connection);
        return 0;
    }

    AWS_ZERO_STRUCT(*resub_arg);
    resub_arg->connection = connection;
    resub_arg->on_suback.multi = on_suback;
    resub_arg->on_suback_ud = on_suback_ud;

    const uint16_t packet_id = mqtt_create_request(
        resub_arg->connection, &s_resubscribe_send, resub_arg, &s_resubscribe_complete, resub_arg, false /* noRetry */);

    if (packet_id != 0) {
        AWS_LOGF_DEBUG(
            AWS_LS_MQTT_CLIENT, "id=%p: Sending multi-topic resubscribe %" PRIu16, (void *)connection, packet_id);
        return packet_id;
    }

    /* The request was never created, so the task arg is still ours to free. */
    AWS_LOGF_ERROR(
        AWS_LS_MQTT_CLIENT,
        "id=%p: Failed to send multi-topic resubscribe with error %s",
        (void *)connection,
        aws_error_name(aws_last_error()));
    aws_mem_release(connection->allocator, resub_arg);

    return 0;
}
2443 
2444 /*******************************************************************************
2445  * Unsubscribe
2446  ******************************************************************************/
2447 
/*
 * Per-request state for an UNSUBSCRIBE, passed as userdata to s_unsubscribe_send / s_unsubscribe_complete.
 */
struct unsubscribe_task_arg {
    /* Connection the unsubscribe is issued on. */
    struct aws_mqtt_client_connection *connection;
    /* Owned copy of the topic filter (destroyed in s_unsubscribe_complete); `filter` is a cursor over it. */
    struct aws_string *filter_string;
    struct aws_byte_cursor filter;
    /* Set by s_unsubscribe_send from the removed topic-tree entry. When true, no UNSUBSCRIBE packet is sent and the
     * request completes immediately. */
    bool is_local;
    /* Packet to populate */
    struct aws_mqtt_packet_unsubscribe unsubscribe;

    /* true if transaction was committed to the topic tree, false requires a retry */
    bool tree_updated;

    /* Optional user completion callback and its user data, invoked from s_unsubscribe_complete. */
    aws_mqtt_op_complete_fn *on_unsuback;
    void *on_unsuback_ud;

    /* Link to the pending timeout task (if any). Set up by s_unsubscribe_send and severed by whichever of the
     * completion or the timeout runs first. */
    struct request_timeout_wrapper timeout_wrapper;
};
2464 
s_unsubscribe_send(uint16_t packet_id,bool is_first_attempt,void * userdata)2465 static enum aws_mqtt_client_request_state s_unsubscribe_send(
2466     uint16_t packet_id,
2467     bool is_first_attempt,
2468     void *userdata) {
2469 
2470     (void)is_first_attempt;
2471 
2472     struct unsubscribe_task_arg *task_arg = userdata;
2473     struct aws_io_message *message = NULL;
2474 
2475     AWS_LOGF_TRACE(
2476         AWS_LS_MQTT_CLIENT,
2477         "id=%p: Attempting send of unsubscribe %" PRIu16 " %s",
2478         (void *)task_arg->connection,
2479         packet_id,
2480         is_first_attempt ? "first attempt" : "resend");
2481 
2482     static const size_t num_topics = 1;
2483 
2484     AWS_VARIABLE_LENGTH_ARRAY(uint8_t, transaction_buf, num_topics * aws_mqtt_topic_tree_action_size);
2485     struct aws_array_list transaction;
2486     aws_array_list_init_static(&transaction, transaction_buf, num_topics, aws_mqtt_topic_tree_action_size);
2487 
2488     if (!task_arg->tree_updated) {
2489 
2490         struct subscribe_task_topic *topic;
2491         if (aws_mqtt_topic_tree_transaction_remove(
2492                 &task_arg->connection->thread_data.subscriptions, &transaction, &task_arg->filter, (void **)&topic)) {
2493             goto handle_error;
2494         }
2495 
2496         task_arg->is_local = topic ? topic->is_local : false;
2497     }
2498 
2499     if (!task_arg->is_local) {
2500         if (task_arg->unsubscribe.fixed_header.packet_type == 0) {
2501             /* If unsubscribe packet is uninitialized, init it */
2502             if (aws_mqtt_packet_unsubscribe_init(&task_arg->unsubscribe, task_arg->connection->allocator, packet_id)) {
2503                 goto handle_error;
2504             }
2505             if (aws_mqtt_packet_unsubscribe_add_topic(&task_arg->unsubscribe, task_arg->filter)) {
2506                 goto handle_error;
2507             }
2508         }
2509 
2510         message = mqtt_get_message_for_packet(task_arg->connection, &task_arg->unsubscribe.fixed_header);
2511         if (!message) {
2512             goto handle_error;
2513         }
2514 
2515         if (aws_mqtt_packet_unsubscribe_encode(&message->message_data, &task_arg->unsubscribe)) {
2516             goto handle_error;
2517         }
2518 
2519         if (aws_channel_slot_send_message(task_arg->connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
2520             goto handle_error;
2521         }
2522 
2523         /* TODO: timing should start from the message written into the socket, which is aws_io_message->on_completion
2524          * invoked, but there are bugs in the websocket handler (and maybe also the h1 handler?) where we don't properly
2525          * fire fire the on_completion callbacks. */
2526         struct request_timeout_task_arg *timeout_task_arg = s_schedule_timeout_task(task_arg->connection, packet_id);
2527         if (!timeout_task_arg) {
2528             return AWS_MQTT_CLIENT_REQUEST_ERROR;
2529         }
2530 
2531         /*
2532          * Set up mutual references between the operation task args and the timeout task args.  Whoever runs first
2533          * "wins", does its logic, and then breaks the connection between the two.
2534          */
2535         task_arg->timeout_wrapper.timeout_task_arg = timeout_task_arg;
2536         timeout_task_arg->task_arg_wrapper = &task_arg->timeout_wrapper;
2537     }
2538 
2539     if (!task_arg->tree_updated) {
2540         aws_mqtt_topic_tree_transaction_commit(&task_arg->connection->thread_data.subscriptions, &transaction);
2541         task_arg->tree_updated = true;
2542     }
2543 
2544     aws_array_list_clean_up(&transaction);
2545     /* If the subscribe is local-only, don't wait for a SUBACK to come back. */
2546     return task_arg->is_local ? AWS_MQTT_CLIENT_REQUEST_COMPLETE : AWS_MQTT_CLIENT_REQUEST_ONGOING;
2547 
2548 handle_error:
2549 
2550     if (message) {
2551         aws_mem_release(message->allocator, message);
2552     }
2553     if (!task_arg->tree_updated) {
2554         aws_mqtt_topic_tree_transaction_roll_back(&task_arg->connection->thread_data.subscriptions, &transaction);
2555     }
2556 
2557     aws_array_list_clean_up(&transaction);
2558     return AWS_MQTT_CLIENT_REQUEST_ERROR;
2559 }
2560 
s_unsubscribe_complete(struct aws_mqtt_client_connection * connection,uint16_t packet_id,int error_code,void * userdata)2561 static void s_unsubscribe_complete(
2562     struct aws_mqtt_client_connection *connection,
2563     uint16_t packet_id,
2564     int error_code,
2565     void *userdata) {
2566 
2567     struct unsubscribe_task_arg *task_arg = userdata;
2568 
2569     AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Unsubscribe %" PRIu16 " complete", (void *)connection, packet_id);
2570 
2571     /*
2572      * If we have a forward pointer to a timeout task, then that means the timeout task has not run yet.  So we should
2573      * follow it and zero out the back pointer to us, because we're going away now.  The timeout task will run later
2574      * and be harmless (even vs. future operations with the same packet id) because it only cancels if it has a back
2575      * pointer.
2576      */
2577     if (task_arg->timeout_wrapper.timeout_task_arg) {
2578         task_arg->timeout_wrapper.timeout_task_arg->task_arg_wrapper = NULL;
2579         task_arg->timeout_wrapper.timeout_task_arg = NULL;
2580     }
2581 
2582     if (task_arg->on_unsuback) {
2583         task_arg->on_unsuback(connection, packet_id, error_code, task_arg->on_unsuback_ud);
2584     }
2585 
2586     aws_string_destroy(task_arg->filter_string);
2587     aws_mqtt_packet_unsubscribe_clean_up(&task_arg->unsubscribe);
2588     aws_mem_release(task_arg->connection->allocator, task_arg);
2589 }
2590 
/*
 * Starts an UNSUBSCRIBE for `topic_filter`.
 *
 * The filter is copied, so the caller's cursor only needs to be valid for the duration of this call. If the filter
 * belongs to a local-only subscription, s_unsubscribe_send completes without sending a packet. `on_unsuback`
 * (optional) is invoked with `on_unsuback_ud` when the request completes.
 *
 * Returns the request's packet id, or 0 on failure.
 */
uint16_t aws_mqtt_client_connection_unsubscribe(
    struct aws_mqtt_client_connection *connection,
    const struct aws_byte_cursor *topic_filter,
    aws_mqtt_op_complete_fn *on_unsuback,
    void *on_unsuback_ud) {

    AWS_PRECONDITION(connection);

    if (!aws_mqtt_is_valid_topic_filter(topic_filter)) {
        aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
        return 0;
    }

    struct unsubscribe_task_arg *task_arg =
        aws_mem_calloc(connection->allocator, 1, sizeof(struct unsubscribe_task_arg));
    if (!task_arg) {
        return 0;
    }

    task_arg->connection = connection;
    task_arg->filter_string = aws_string_new_from_array(connection->allocator, topic_filter->ptr, topic_filter->len);
    if (!task_arg->filter_string) {
        /* Building the cursor from a NULL string would dereference it; bail out before that. */
        aws_mem_release(connection->allocator, task_arg);
        return 0;
    }
    task_arg->filter = aws_byte_cursor_from_string(task_arg->filter_string);
    task_arg->on_unsuback = on_unsuback;
    task_arg->on_unsuback_ud = on_unsuback_ud;

    uint16_t packet_id = mqtt_create_request(
        connection, &s_unsubscribe_send, task_arg, s_unsubscribe_complete, task_arg, false /* noRetry */);
    if (packet_id == 0) {
        /* Logged at ERROR like the subscribe/resubscribe entry points' equivalent failures. */
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Failed to start unsubscribe, with error %s",
            (void *)connection,
            aws_error_debug_str(aws_last_error()));
        goto handle_error;
    }

    AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Starting unsubscribe %" PRIu16, (void *)connection, packet_id);

    return packet_id;

handle_error:

    aws_string_destroy(task_arg->filter_string);
    aws_mem_release(connection->allocator, task_arg);

    return 0;
}
2638 
2639 /*******************************************************************************
2640  * Publish
2641  ******************************************************************************/
2642 
/*
 * Per-request state for a PUBLISH, passed as userdata to s_publish_send / s_publish_complete.
 */
struct publish_task_arg {
    /* Connection the publish is issued on. */
    struct aws_mqtt_client_connection *connection;
    /* Owned copy of the topic. NOTE(review): `topic` is presumably a cursor over it, set by the publish entry
     * point (not visible here) - confirm. */
    struct aws_string *topic_string;
    struct aws_byte_cursor topic;
    enum aws_mqtt_qos qos;
    bool retain;
    /* `payload` is what s_publish_send writes to the wire; `payload_buf` is cleaned up in s_publish_complete.
     * NOTE(review): `payload` presumably points into `payload_buf` - confirm in the publish entry point. */
    struct aws_byte_cursor payload;
    struct aws_byte_buf payload_buf;

    /* Packet to populate */
    struct aws_mqtt_packet_publish publish;

    /* Optional user completion callback and its user data, invoked from s_publish_complete. */
    aws_mqtt_op_complete_fn *on_complete;
    void *userdata;

    /* Link to the pending timeout task (if any). Set up by s_publish_send and severed by whichever of the completion
     * or the timeout runs first. */
    struct request_timeout_wrapper timeout_wrapper;
};
2660 
2661 /* should only be called by tests */
s_get_stuff_from_outstanding_requests_table(struct aws_mqtt_client_connection * connection,uint16_t packet_id,struct aws_allocator * allocator,struct aws_byte_buf * result_buf,struct aws_string ** result_string)2662 static int s_get_stuff_from_outstanding_requests_table(
2663     struct aws_mqtt_client_connection *connection,
2664     uint16_t packet_id,
2665     struct aws_allocator *allocator,
2666     struct aws_byte_buf *result_buf,
2667     struct aws_string **result_string) {
2668 
2669     int err = AWS_OP_SUCCESS;
2670 
2671     aws_mutex_lock(&connection->synced_data.lock);
2672     struct aws_hash_element *elem = NULL;
2673     aws_hash_table_find(&connection->synced_data.outstanding_requests_table, &packet_id, &elem);
2674     if (elem) {
2675         struct aws_mqtt_request *request = elem->value;
2676         struct publish_task_arg *pub = (struct publish_task_arg *)request->send_request_ud;
2677         if (result_buf != NULL) {
2678             if (aws_byte_buf_init_copy(result_buf, allocator, &pub->payload_buf)) {
2679                 err = AWS_OP_ERR;
2680             }
2681         } else if (result_string != NULL) {
2682             *result_string = aws_string_new_from_string(allocator, pub->topic_string);
2683             if (*result_string == NULL) {
2684                 err = AWS_OP_ERR;
2685             }
2686         }
2687     } else {
2688         /* So lovely that this error is defined, but hashtable never actually raises it */
2689         err = aws_raise_error(AWS_ERROR_HASHTBL_ITEM_NOT_FOUND);
2690     }
2691     aws_mutex_unlock(&connection->synced_data.lock);
2692 
2693     return err;
2694 }
2695 
2696 /* should only be called by tests */
aws_mqtt_client_get_payload_for_outstanding_publish_packet(struct aws_mqtt_client_connection * connection,uint16_t packet_id,struct aws_allocator * allocator,struct aws_byte_buf * result)2697 int aws_mqtt_client_get_payload_for_outstanding_publish_packet(
2698     struct aws_mqtt_client_connection *connection,
2699     uint16_t packet_id,
2700     struct aws_allocator *allocator,
2701     struct aws_byte_buf *result) {
2702 
2703     AWS_ZERO_STRUCT(*result);
2704     return s_get_stuff_from_outstanding_requests_table(connection, packet_id, allocator, result, NULL);
2705 }
2706 
/*
 * Test-only accessor: copies the topic of the outstanding PUBLISH with `packet_id` into a new string allocated
 * from `allocator` and stores it in `*result`. `*result` is set to NULL first so it holds a defined value when the
 * packet id is not found. Returns AWS_OP_SUCCESS or AWS_OP_ERR.
 */
int aws_mqtt_client_get_topic_for_outstanding_publish_packet(
    struct aws_mqtt_client_connection *connection,
    uint16_t packet_id,
    struct aws_allocator *allocator,
    struct aws_string **result) {

    struct aws_string **topic_out = result;
    *topic_out = NULL;

    return s_get_stuff_from_outstanding_requests_table(
        connection, packet_id, allocator, NULL /* payload not requested */, topic_out);
}
2717 
s_publish_send(uint16_t packet_id,bool is_first_attempt,void * userdata)2718 static enum aws_mqtt_client_request_state s_publish_send(uint16_t packet_id, bool is_first_attempt, void *userdata) {
2719     struct publish_task_arg *task_arg = userdata;
2720     struct aws_mqtt_client_connection *connection = task_arg->connection;
2721 
2722     AWS_LOGF_TRACE(
2723         AWS_LS_MQTT_CLIENT,
2724         "id=%p: Attempting send of publish %" PRIu16 " %s",
2725         (void *)task_arg->connection,
2726         packet_id,
2727         is_first_attempt ? "first attempt" : "resend");
2728 
2729     bool is_qos_0 = task_arg->qos == AWS_MQTT_QOS_AT_MOST_ONCE;
2730     if (is_qos_0) {
2731         packet_id = 0;
2732     }
2733 
2734     if (is_first_attempt) {
2735         if (aws_mqtt_packet_publish_init(
2736                 &task_arg->publish,
2737                 task_arg->retain,
2738                 task_arg->qos,
2739                 !is_first_attempt,
2740                 task_arg->topic,
2741                 packet_id,
2742                 task_arg->payload)) {
2743 
2744             return AWS_MQTT_CLIENT_REQUEST_ERROR;
2745         }
2746     }
2747 
2748     struct aws_io_message *message = mqtt_get_message_for_packet(task_arg->connection, &task_arg->publish.fixed_header);
2749     if (!message) {
2750         return AWS_MQTT_CLIENT_REQUEST_ERROR;
2751     }
2752 
2753     /* Encode the headers, and everything but the payload */
2754     if (aws_mqtt_packet_publish_encode_headers(&message->message_data, &task_arg->publish)) {
2755         return AWS_MQTT_CLIENT_REQUEST_ERROR;
2756     }
2757 
2758     struct aws_byte_cursor payload_cur = task_arg->payload;
2759     {
2760     write_payload_chunk:
2761         (void)NULL;
2762 
2763         const size_t left_in_message = message->message_data.capacity - message->message_data.len;
2764         const size_t to_write = payload_cur.len < left_in_message ? payload_cur.len : left_in_message;
2765 
2766         if (to_write) {
2767             /* Write this chunk */
2768             struct aws_byte_cursor to_write_cur = aws_byte_cursor_advance(&payload_cur, to_write);
2769             AWS_ASSERT(to_write_cur.ptr); /* to_write is guaranteed to be inside the bounds of payload_cur */
2770             if (!aws_byte_buf_write_from_whole_cursor(&message->message_data, to_write_cur)) {
2771 
2772                 aws_mem_release(message->allocator, message);
2773                 return AWS_MQTT_CLIENT_REQUEST_ERROR;
2774             }
2775         }
2776 
2777         if (aws_channel_slot_send_message(task_arg->connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
2778             aws_mem_release(message->allocator, message);
2779             /* If it's QoS 0, telling user that the message haven't been sent, else, the message will be resent once the
2780              * connection is back */
2781             return is_qos_0 ? AWS_MQTT_CLIENT_REQUEST_ERROR : AWS_MQTT_CLIENT_REQUEST_ONGOING;
2782         }
2783 
2784         /* If there's still payload left, get a new message and start again. */
2785         if (payload_cur.len) {
2786             message = mqtt_get_message_for_packet(task_arg->connection, &task_arg->publish.fixed_header);
2787             goto write_payload_chunk;
2788         }
2789     }
2790     if (!is_qos_0 && connection->operation_timeout_ns != UINT64_MAX) {
2791         /* TODO: timing should start from the message written into the socket, which is aws_io_message->on_completion
2792          * invoked, but there are bugs in the websocket handler (and maybe also the h1 handler?) where we don't properly
2793          * fire fire the on_completion callbacks. */
2794         struct request_timeout_task_arg *timeout_task_arg = s_schedule_timeout_task(connection, packet_id);
2795         if (!timeout_task_arg) {
2796             return AWS_MQTT_CLIENT_REQUEST_ERROR;
2797         }
2798 
2799         /*
2800          * Set up mutual references between the operation task args and the timeout task args.  Whoever runs first
2801          * "wins", does its logic, and then breaks the connection between the two.
2802          */
2803         task_arg->timeout_wrapper.timeout_task_arg = timeout_task_arg;
2804         timeout_task_arg->task_arg_wrapper = &task_arg->timeout_wrapper;
2805     }
2806 
2807     /* If QoS == 0, there will be no ack, so consider the request done now. */
2808     return is_qos_0 ? AWS_MQTT_CLIENT_REQUEST_COMPLETE : AWS_MQTT_CLIENT_REQUEST_ONGOING;
2809 }
2810 
s_publish_complete(struct aws_mqtt_client_connection * connection,uint16_t packet_id,int error_code,void * userdata)2811 static void s_publish_complete(
2812     struct aws_mqtt_client_connection *connection,
2813     uint16_t packet_id,
2814     int error_code,
2815     void *userdata) {
2816     struct publish_task_arg *task_arg = userdata;
2817 
2818     AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Publish %" PRIu16 " complete", (void *)connection, packet_id);
2819 
2820     if (task_arg->on_complete) {
2821         task_arg->on_complete(connection, packet_id, error_code, task_arg->userdata);
2822     }
2823 
2824     /*
2825      * If we have a forward pointer to a timeout task, then that means the timeout task has not run yet.  So we should
2826      * follow it and zero out the back pointer to us, because we're going away now.  The timeout task will run later
2827      * and be harmless (even vs. future operations with the same packet id) because it only cancels if it has a back
2828      * pointer.
2829      */
2830     if (task_arg->timeout_wrapper.timeout_task_arg != NULL) {
2831         task_arg->timeout_wrapper.timeout_task_arg->task_arg_wrapper = NULL;
2832         task_arg->timeout_wrapper.timeout_task_arg = NULL;
2833     }
2834 
2835     aws_byte_buf_clean_up(&task_arg->payload_buf);
2836     aws_string_destroy(task_arg->topic_string);
2837     aws_mem_release(connection->allocator, task_arg);
2838 }
2839 
/**
 * Starts a PUBLISH of `payload` to `topic` on `connection`.
 *
 * The topic and payload are copied into a heap-allocated publish_task_arg, so the caller's buffers do not need to
 * outlive this call. If mqtt_create_request() succeeds, the arg is owned by the request and is freed by
 * s_publish_complete() (registered as the completion callback); on any failure here it is freed before returning.
 *
 * @param connection   connection to publish on (must not be NULL)
 * @param topic        topic to publish to; must pass aws_mqtt_is_valid_topic()
 * @param qos          requested QoS; QoS 0 requests are created with retries disabled
 * @param retain       value of the RETAIN flag
 * @param payload      bytes to publish; a NULL pointer is treated as an empty payload
 * @param on_complete  optional callback invoked from s_publish_complete()
 * @param userdata     passed through to on_complete
 * @return the packet id of the new request, or 0 on failure
 */
uint16_t aws_mqtt_client_connection_publish(
    struct aws_mqtt_client_connection *connection,
    const struct aws_byte_cursor *topic,
    enum aws_mqtt_qos qos,
    bool retain,
    const struct aws_byte_cursor *payload,
    aws_mqtt_op_complete_fn *on_complete,
    void *userdata) {

    AWS_PRECONDITION(connection);

    if (!aws_mqtt_is_valid_topic(topic)) {
        aws_raise_error(AWS_ERROR_MQTT_INVALID_TOPIC);
        return 0;
    }

    struct publish_task_arg *arg = aws_mem_calloc(connection->allocator, 1, sizeof(struct publish_task_arg));
    if (!arg) {
        return 0;
    }

    arg->connection = connection;
    arg->topic_string = aws_string_new_from_array(connection->allocator, topic->ptr, topic->len);
    if (!arg->topic_string) {
        /* Allocation failed; aws_byte_cursor_from_string() must not be handed a NULL string. */
        goto handle_error;
    }
    arg->topic = aws_byte_cursor_from_string(arg->topic_string);
    arg->qos = qos;
    arg->retain = retain;

    /* A NULL payload pointer is published as a zero-length payload instead of being dereferenced. */
    struct aws_byte_cursor payload_cur;
    AWS_ZERO_STRUCT(payload_cur);
    if (payload != NULL) {
        payload_cur = *payload;
    }
    if (aws_byte_buf_init_copy_from_cursor(&arg->payload_buf, connection->allocator, payload_cur)) {
        goto handle_error;
    }
    arg->payload = aws_byte_cursor_from_buf(&arg->payload_buf);
    arg->on_complete = on_complete;
    arg->userdata = userdata;

    /* QoS 0 has no acknowledgement to wait for, so the request is created with retries disabled. */
    bool no_retry = qos == AWS_MQTT_QOS_AT_MOST_ONCE;
    uint16_t packet_id = mqtt_create_request(connection, &s_publish_send, arg, &s_publish_complete, arg, no_retry);

    if (packet_id == 0) {
        /* bummer, we failed to make a new request */
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Failed starting publish to topic " PRInSTR ",error %d (%s)",
            (void *)connection,
            AWS_BYTE_CURSOR_PRI(*topic),
            aws_last_error(),
            aws_error_name(aws_last_error()));
        goto handle_error;
    }

    AWS_LOGF_DEBUG(
        AWS_LS_MQTT_CLIENT,
        "id=%p: Starting publish %" PRIu16 " to topic " PRInSTR,
        (void *)connection,
        packet_id,
        AWS_BYTE_CURSOR_PRI(*topic));
    return packet_id;

handle_error:

    /* arg is valid here; topic_string and payload_buf may still be zeroed from calloc */
    if (arg->topic_string) {
        aws_string_destroy(arg->topic_string);
    }

    aws_byte_buf_clean_up(&arg->payload_buf);

    aws_mem_release(connection->allocator, arg);

    return 0;
}
2909 
2910 /*******************************************************************************
2911  * Ping
2912  ******************************************************************************/
2913 
s_pingresp_received_timeout(struct aws_channel_task * channel_task,void * arg,enum aws_task_status status)2914 static void s_pingresp_received_timeout(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
2915     struct aws_mqtt_client_connection *connection = arg;
2916 
2917     if (status == AWS_TASK_STATUS_RUN_READY) {
2918         /* Check that a pingresp has been received since pingreq was sent */
2919         if (connection->thread_data.waiting_on_ping_response) {
2920             connection->thread_data.waiting_on_ping_response = false;
2921             /* It's been too long since the last ping, close the connection */
2922             AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: ping timeout detected", (void *)connection);
2923             aws_channel_shutdown(connection->slot->channel, AWS_ERROR_MQTT_TIMEOUT);
2924         }
2925     }
2926 
2927     aws_mem_release(connection->allocator, channel_task);
2928 }
2929 
s_pingreq_send(uint16_t packet_id,bool is_first_attempt,void * userdata)2930 static enum aws_mqtt_client_request_state s_pingreq_send(uint16_t packet_id, bool is_first_attempt, void *userdata) {
2931     (void)packet_id;
2932     (void)is_first_attempt;
2933     AWS_PRECONDITION(is_first_attempt);
2934 
2935     struct aws_mqtt_client_connection *connection = userdata;
2936 
2937     AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: pingreq send", (void *)connection);
2938     struct aws_mqtt_packet_connection pingreq;
2939     aws_mqtt_packet_pingreq_init(&pingreq);
2940 
2941     struct aws_io_message *message = mqtt_get_message_for_packet(connection, &pingreq.fixed_header);
2942     if (!message) {
2943         return AWS_MQTT_CLIENT_REQUEST_ERROR;
2944     }
2945 
2946     if (aws_mqtt_packet_connection_encode(&message->message_data, &pingreq)) {
2947         aws_mem_release(message->allocator, message);
2948         return AWS_MQTT_CLIENT_REQUEST_ERROR;
2949     }
2950 
2951     if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
2952         aws_mem_release(message->allocator, message);
2953         return AWS_MQTT_CLIENT_REQUEST_ERROR;
2954     }
2955 
2956     /* Mark down that now is when the last pingreq was sent */
2957     connection->thread_data.waiting_on_ping_response = true;
2958 
2959     struct aws_channel_task *ping_timeout_task =
2960         aws_mem_calloc(connection->allocator, 1, sizeof(struct aws_channel_task));
2961     if (!ping_timeout_task) {
2962         /* allocation failed, no log, just return error. */
2963         goto error;
2964     }
2965     aws_channel_task_init(ping_timeout_task, s_pingresp_received_timeout, connection, "mqtt_pingresp_timeout");
2966     uint64_t now = 0;
2967     if (aws_channel_current_clock_time(connection->slot->channel, &now)) {
2968         goto error;
2969     }
2970     now += connection->ping_timeout_ns;
2971     aws_channel_schedule_task_future(connection->slot->channel, ping_timeout_task, now);
2972     return AWS_MQTT_CLIENT_REQUEST_COMPLETE;
2973 
2974 error:
2975     return AWS_MQTT_CLIENT_REQUEST_ERROR;
2976 }
2977 
aws_mqtt_client_connection_ping(struct aws_mqtt_client_connection * connection)2978 int aws_mqtt_client_connection_ping(struct aws_mqtt_client_connection *connection) {
2979 
2980     AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Starting ping", (void *)connection);
2981 
2982     uint16_t packet_id = mqtt_create_request(connection, &s_pingreq_send, connection, NULL, NULL, true /* noRetry */);
2983 
2984     AWS_LOGF_DEBUG(AWS_LS_MQTT_CLIENT, "id=%p: Starting ping with packet id %" PRIu16, (void *)connection, packet_id);
2985 
2986     return (packet_id > 0) ? AWS_OP_SUCCESS : AWS_OP_ERR;
2987 }
2988