1 /**
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0.
4 */
5
6 #include <aws/mqtt/private/client_impl.h>
7
8 #include <aws/mqtt/private/packets.h>
9 #include <aws/mqtt/private/topic_tree.h>
10
11 #include <aws/io/logging.h>
12
13 #include <aws/common/clock.h>
14 #include <aws/common/math.h>
15 #include <aws/common/task_scheduler.h>
16
17 #include <inttypes.h>
18
19 #ifdef _MSC_VER
20 # pragma warning(disable : 4204)
21 #endif
22
23 /*******************************************************************************
24 * Packet State Machine
25 ******************************************************************************/
26
27 typedef int(packet_handler_fn)(struct aws_mqtt_client_connection *connection, struct aws_byte_cursor message_cursor);
28
s_packet_handler_default(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)29 static int s_packet_handler_default(
30 struct aws_mqtt_client_connection *connection,
31 struct aws_byte_cursor message_cursor) {
32 (void)connection;
33 (void)message_cursor;
34
35 AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Unhandled packet type received", (void *)connection);
36 return aws_raise_error(AWS_ERROR_MQTT_INVALID_PACKET_TYPE);
37 }
38
39 static void s_on_time_to_ping(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status);
s_schedule_ping(struct aws_mqtt_client_connection * connection)40 static void s_schedule_ping(struct aws_mqtt_client_connection *connection) {
41 aws_channel_task_init(&connection->ping_task, s_on_time_to_ping, connection, "mqtt_ping");
42
43 uint64_t now = 0;
44 aws_channel_current_clock_time(connection->slot->channel, &now);
45 AWS_LOGF_TRACE(
46 AWS_LS_MQTT_CLIENT, "id=%p: Scheduling PING. current timestamp is %" PRIu64, (void *)connection, now);
47
48 uint64_t schedule_time =
49 now + aws_timestamp_convert(connection->keep_alive_time_secs, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL);
50
51 AWS_LOGF_TRACE(
52 AWS_LS_MQTT_CLIENT,
53 "id=%p: The next ping will be run at timestamp %" PRIu64,
54 (void *)connection,
55 schedule_time);
56 aws_channel_schedule_task_future(connection->slot->channel, &connection->ping_task, schedule_time);
57 }
58
s_on_time_to_ping(struct aws_channel_task * channel_task,void * arg,enum aws_task_status status)59 static void s_on_time_to_ping(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
60 (void)channel_task;
61
62 if (status == AWS_TASK_STATUS_RUN_READY) {
63 struct aws_mqtt_client_connection *connection = arg;
64 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Sending PING", (void *)connection);
65 aws_mqtt_client_connection_ping(connection);
66 s_schedule_ping(connection);
67 }
68 }
s_packet_handler_connack(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)69 static int s_packet_handler_connack(
70 struct aws_mqtt_client_connection *connection,
71 struct aws_byte_cursor message_cursor) {
72
73 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: CONNACK received", (void *)connection);
74
75 struct aws_mqtt_packet_connack connack;
76 if (aws_mqtt_packet_connack_decode(&message_cursor, &connack)) {
77 AWS_LOGF_ERROR(
78 AWS_LS_MQTT_CLIENT, "id=%p: error %d parsing CONNACK packet", (void *)connection, aws_last_error());
79
80 return AWS_OP_ERR;
81 }
82 bool was_reconnecting;
83 struct aws_linked_list requests;
84 aws_linked_list_init(&requests);
85 { /* BEGIN CRITICAL SECTION */
86 mqtt_connection_lock_synced_data(connection);
87 /* User requested disconnect, don't do anything */
88 if (connection->synced_data.state >= AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
89 mqtt_connection_unlock_synced_data(connection);
90 AWS_LOGF_TRACE(
91 AWS_LS_MQTT_CLIENT, "id=%p: User has requested disconnect, dropping connection", (void *)connection);
92 return AWS_OP_SUCCESS;
93 }
94
95 was_reconnecting = connection->synced_data.state == AWS_MQTT_CLIENT_STATE_RECONNECTING;
96 if (connack.connect_return_code == AWS_MQTT_CONNECT_ACCEPTED) {
97 AWS_LOGF_DEBUG(
98 AWS_LS_MQTT_CLIENT,
99 "id=%p: connection was accepted, switch state from %d to CONNECTED.",
100 (void *)connection,
101 (int)connection->synced_data.state);
102 /* Don't change the state if it's not ACCEPTED by broker */
103 mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_CONNECTED);
104 aws_linked_list_swap_contents(&connection->synced_data.pending_requests_list, &requests);
105 }
106 mqtt_connection_unlock_synced_data(connection);
107 } /* END CRITICAL SECTION */
108 connection->connection_count++;
109
110 /* Reset the current timeout timer */
111 connection->reconnect_timeouts.current = connection->reconnect_timeouts.min;
112
113 if (connack.connect_return_code == AWS_MQTT_CONNECT_ACCEPTED) {
114 /* If successfully connected, schedule all pending tasks */
115 AWS_LOGF_TRACE(
116 AWS_LS_MQTT_CLIENT, "id=%p: connection was accepted processing offline requests.", (void *)connection);
117
118 if (!aws_linked_list_empty(&requests)) {
119
120 struct aws_linked_list_node *current = aws_linked_list_front(&requests);
121 const struct aws_linked_list_node *end = aws_linked_list_end(&requests);
122
123 do {
124 struct aws_mqtt_request *request = AWS_CONTAINER_OF(current, struct aws_mqtt_request, list_node);
125 AWS_LOGF_TRACE(
126 AWS_LS_MQTT_CLIENT,
127 "id=%p: processing offline request %" PRIu16,
128 (void *)connection,
129 request->packet_id);
130 aws_channel_schedule_task_now(connection->slot->channel, &request->outgoing_task);
131 current = current->next;
132 } while (current != end);
133 }
134 } else {
135 AWS_LOGF_ERROR(
136 AWS_LS_MQTT_CLIENT,
137 "id=%p: invalid connect return code %d, disconnecting",
138 (void *)connection,
139 connack.connect_return_code);
140 /* If error code returned, disconnect, on_completed will be invoked from shutdown process */
141 aws_channel_shutdown(connection->slot->channel, AWS_ERROR_MQTT_PROTOCOL_ERROR);
142
143 return AWS_OP_SUCCESS;
144 }
145
146 /* It is possible for a connection to complete, and a hangup to occur before the
147 * CONNECT/CONNACK cycle completes. In that case, we must deliver on_connection_complete
148 * on the first successful CONNACK or user code will never think it's connected */
149 if (was_reconnecting && connection->connection_count > 1) {
150
151 AWS_LOGF_TRACE(
152 AWS_LS_MQTT_CLIENT,
153 "id=%p: connection is a resumed connection, invoking on_resumed callback",
154 (void *)connection);
155
156 MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_resumed, connack.connect_return_code, connack.session_present);
157 } else {
158
159 aws_create_reconnect_task(connection);
160
161 AWS_LOGF_TRACE(
162 AWS_LS_MQTT_CLIENT,
163 "id=%p: connection is a new connection, invoking on_connection_complete callback",
164 (void *)connection);
165 MQTT_CLIENT_CALL_CALLBACK_ARGS(
166 connection, on_connection_complete, AWS_OP_SUCCESS, connack.connect_return_code, connack.session_present);
167 }
168
169 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: connection callback completed", (void *)connection);
170
171 s_schedule_ping(connection);
172 return AWS_OP_SUCCESS;
173 }
174
s_packet_handler_publish(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)175 static int s_packet_handler_publish(
176 struct aws_mqtt_client_connection *connection,
177 struct aws_byte_cursor message_cursor) {
178
179 /* TODO: need to handle the QoS 2 message to avoid processing the message a second time */
180 struct aws_mqtt_packet_publish publish;
181 if (aws_mqtt_packet_publish_decode(&message_cursor, &publish)) {
182 return AWS_OP_ERR;
183 }
184
185 aws_mqtt_topic_tree_publish(&connection->thread_data.subscriptions, &publish);
186
187 bool dup = aws_mqtt_packet_publish_get_dup(&publish);
188 enum aws_mqtt_qos qos = aws_mqtt_packet_publish_get_qos(&publish);
189 bool retain = aws_mqtt_packet_publish_get_retain(&publish);
190
191 MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_any_publish, &publish.topic_name, &publish.payload, dup, qos, retain);
192
193 AWS_LOGF_TRACE(
194 AWS_LS_MQTT_CLIENT,
195 "id=%p: publish received with msg id=%" PRIu16 " dup=%d qos=%d retain=%d payload-size=%zu topic=" PRInSTR,
196 (void *)connection,
197 publish.packet_identifier,
198 dup,
199 qos,
200 retain,
201 publish.payload.len,
202 AWS_BYTE_CURSOR_PRI(publish.topic_name));
203 struct aws_mqtt_packet_ack puback;
204 AWS_ZERO_STRUCT(puback);
205
206 /* Switch on QoS flags (bits 1 & 2) */
207 switch (qos) {
208 case AWS_MQTT_QOS_AT_MOST_ONCE:
209 AWS_LOGF_TRACE(
210 AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 0, not sending puback", (void *)connection);
211 /* No more communication necessary */
212 break;
213 case AWS_MQTT_QOS_AT_LEAST_ONCE:
214 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 1, sending puback", (void *)connection);
215 aws_mqtt_packet_puback_init(&puback, publish.packet_identifier);
216 break;
217 case AWS_MQTT_QOS_EXACTLY_ONCE:
218 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 2, sending pubrec", (void *)connection);
219 aws_mqtt_packet_pubrec_init(&puback, publish.packet_identifier);
220 break;
221 default:
222 /* Impossible to hit this branch. QoS value is checked when decoding */
223 AWS_FATAL_ASSERT(0);
224 break;
225 }
226
227 if (puback.packet_identifier) {
228 struct aws_io_message *message = mqtt_get_message_for_packet(connection, &puback.fixed_header);
229 if (!message) {
230 return AWS_OP_ERR;
231 }
232
233 if (aws_mqtt_packet_ack_encode(&message->message_data, &puback)) {
234 aws_mem_release(message->allocator, message);
235 return AWS_OP_ERR;
236 }
237
238 if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
239 aws_mem_release(message->allocator, message);
240 return AWS_OP_ERR;
241 }
242 }
243
244 return AWS_OP_SUCCESS;
245 }
246
s_packet_handler_ack(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)247 static int s_packet_handler_ack(struct aws_mqtt_client_connection *connection, struct aws_byte_cursor message_cursor) {
248 struct aws_mqtt_packet_ack ack;
249 if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
250 return AWS_OP_ERR;
251 }
252
253 AWS_LOGF_TRACE(
254 AWS_LS_MQTT_CLIENT, "id=%p: received ack for message id %" PRIu16, (void *)connection, ack.packet_identifier);
255
256 mqtt_request_complete(connection, AWS_ERROR_SUCCESS, ack.packet_identifier);
257
258 return AWS_OP_SUCCESS;
259 }
260
s_packet_handler_suback(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)261 static int s_packet_handler_suback(
262 struct aws_mqtt_client_connection *connection,
263 struct aws_byte_cursor message_cursor) {
264 struct aws_mqtt_packet_suback suback;
265 if (aws_mqtt_packet_suback_init(&suback, connection->allocator, 0 /* fake packet_id */)) {
266 return AWS_OP_ERR;
267 }
268
269 if (aws_mqtt_packet_suback_decode(&message_cursor, &suback)) {
270 goto error;
271 }
272
273 AWS_LOGF_TRACE(
274 AWS_LS_MQTT_CLIENT,
275 "id=%p: received suback for message id %" PRIu16,
276 (void *)connection,
277 suback.packet_identifier);
278
279 struct aws_mqtt_request *request = NULL;
280
281 { /* BEGIN CRITICAL SECTION */
282 mqtt_connection_lock_synced_data(connection);
283 struct aws_hash_element *elem = NULL;
284 aws_hash_table_find(&connection->synced_data.outstanding_requests_table, &suback.packet_identifier, &elem);
285 if (elem != NULL) {
286 request = elem->value;
287 }
288 mqtt_connection_unlock_synced_data(connection);
289 } /* END CRITICAL SECTION */
290
291 if (request == NULL) {
292 /* no corresponding request found */
293 goto done;
294 }
295
296 struct subscribe_task_arg *task_arg = request->on_complete_ud;
297 size_t request_topics_len = aws_array_list_length(&task_arg->topics);
298 size_t suback_return_code_len = aws_array_list_length(&suback.return_codes);
299 if (request_topics_len != suback_return_code_len) {
300 goto error;
301 }
302 size_t num_filters = aws_array_list_length(&suback.return_codes);
303 for (size_t i = 0; i < num_filters; ++i) {
304
305 uint8_t return_code = 0;
306 struct subscribe_task_topic *topic = NULL;
307 aws_array_list_get_at(&suback.return_codes, (void *)&return_code, i);
308 aws_array_list_get_at(&task_arg->topics, &topic, i);
309 topic->request.qos = return_code;
310 }
311
312 done:
313 mqtt_request_complete(connection, AWS_ERROR_SUCCESS, suback.packet_identifier);
314 aws_mqtt_packet_suback_clean_up(&suback);
315 return AWS_OP_SUCCESS;
316 error:
317 aws_mqtt_packet_suback_clean_up(&suback);
318 return AWS_OP_ERR;
319 }
320
s_packet_handler_pubrec(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)321 static int s_packet_handler_pubrec(
322 struct aws_mqtt_client_connection *connection,
323 struct aws_byte_cursor message_cursor) {
324
325 struct aws_mqtt_packet_ack ack;
326 if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
327 return AWS_OP_ERR;
328 }
329
330 /* TODO: When sending PUBLISH with QoS 2, we should be storing the data until this packet is received, at which
331 * point we may discard it. */
332
333 /* Send PUBREL */
334 aws_mqtt_packet_pubrel_init(&ack, ack.packet_identifier);
335 struct aws_io_message *message = mqtt_get_message_for_packet(connection, &ack.fixed_header);
336 if (!message) {
337 return AWS_OP_ERR;
338 }
339
340 if (aws_mqtt_packet_ack_encode(&message->message_data, &ack)) {
341 goto on_error;
342 }
343
344 if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
345 goto on_error;
346 }
347
348 return AWS_OP_SUCCESS;
349
350 on_error:
351
352 if (message) {
353 aws_mem_release(message->allocator, message);
354 }
355
356 return AWS_OP_ERR;
357 }
358
s_packet_handler_pubrel(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)359 static int s_packet_handler_pubrel(
360 struct aws_mqtt_client_connection *connection,
361 struct aws_byte_cursor message_cursor) {
362
363 struct aws_mqtt_packet_ack ack;
364 if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
365 return AWS_OP_ERR;
366 }
367
368 /* Send PUBCOMP */
369 aws_mqtt_packet_pubcomp_init(&ack, ack.packet_identifier);
370 struct aws_io_message *message = mqtt_get_message_for_packet(connection, &ack.fixed_header);
371 if (!message) {
372 return AWS_OP_ERR;
373 }
374
375 if (aws_mqtt_packet_ack_encode(&message->message_data, &ack)) {
376 goto on_error;
377 }
378
379 if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
380 goto on_error;
381 }
382
383 return AWS_OP_SUCCESS;
384
385 on_error:
386
387 if (message) {
388 aws_mem_release(message->allocator, message);
389 }
390
391 return AWS_OP_ERR;
392 }
393
s_packet_handler_pingresp(struct aws_mqtt_client_connection * connection,struct aws_byte_cursor message_cursor)394 static int s_packet_handler_pingresp(
395 struct aws_mqtt_client_connection *connection,
396 struct aws_byte_cursor message_cursor) {
397
398 (void)message_cursor;
399
400 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: PINGRESP received", (void *)connection);
401
402 connection->thread_data.waiting_on_ping_response = false;
403
404 return AWS_OP_SUCCESS;
405 }
406
407 /* Bake up a big ol' function table just like Gramma used to make */
408 static packet_handler_fn *s_packet_handlers[] = {
409 [AWS_MQTT_PACKET_CONNECT] = &s_packet_handler_default,
410 [AWS_MQTT_PACKET_CONNACK] = &s_packet_handler_connack,
411 [AWS_MQTT_PACKET_PUBLISH] = &s_packet_handler_publish,
412 [AWS_MQTT_PACKET_PUBACK] = &s_packet_handler_ack,
413 [AWS_MQTT_PACKET_PUBREC] = &s_packet_handler_pubrec,
414 [AWS_MQTT_PACKET_PUBREL] = &s_packet_handler_pubrel,
415 [AWS_MQTT_PACKET_PUBCOMP] = &s_packet_handler_ack,
416 [AWS_MQTT_PACKET_SUBSCRIBE] = &s_packet_handler_default,
417 [AWS_MQTT_PACKET_SUBACK] = &s_packet_handler_suback,
418 [AWS_MQTT_PACKET_UNSUBSCRIBE] = &s_packet_handler_default,
419 [AWS_MQTT_PACKET_UNSUBACK] = &s_packet_handler_ack,
420 [AWS_MQTT_PACKET_PINGREQ] = &s_packet_handler_default,
421 [AWS_MQTT_PACKET_PINGRESP] = &s_packet_handler_pingresp,
422 [AWS_MQTT_PACKET_DISCONNECT] = &s_packet_handler_default,
423 };
424
425 /*******************************************************************************
426 * Channel Handler
427 ******************************************************************************/
428
s_process_mqtt_packet(struct aws_mqtt_client_connection * connection,enum aws_mqtt_packet_type packet_type,struct aws_byte_cursor packet)429 static int s_process_mqtt_packet(
430 struct aws_mqtt_client_connection *connection,
431 enum aws_mqtt_packet_type packet_type,
432 struct aws_byte_cursor packet) {
433 { /* BEGIN CRITICAL SECTION */
434 mqtt_connection_lock_synced_data(connection);
435 /* [MQTT-3.2.0-1] The first packet sent from the Server to the Client MUST be a CONNACK Packet */
436 if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_CONNECTING &&
437 packet_type != AWS_MQTT_PACKET_CONNACK) {
438 mqtt_connection_unlock_synced_data(connection);
439 AWS_LOGF_ERROR(
440 AWS_LS_MQTT_CLIENT,
441 "id=%p: First message received from the server was not a CONNACK. Terminating connection.",
442 (void *)connection);
443 aws_channel_shutdown(connection->slot->channel, AWS_ERROR_MQTT_PROTOCOL_ERROR);
444 return aws_raise_error(AWS_ERROR_MQTT_PROTOCOL_ERROR);
445 }
446 mqtt_connection_unlock_synced_data(connection);
447 } /* END CRITICAL SECTION */
448
449 if (AWS_UNLIKELY(packet_type > AWS_MQTT_PACKET_DISCONNECT || packet_type < AWS_MQTT_PACKET_CONNECT)) {
450 AWS_LOGF_ERROR(
451 AWS_LS_MQTT_CLIENT,
452 "id=%p: Invalid packet type received %d. Terminating connection.",
453 (void *)connection,
454 packet_type);
455 return aws_raise_error(AWS_ERROR_MQTT_INVALID_PACKET_TYPE);
456 }
457
458 /* Handle the packet */
459 return s_packet_handlers[packet_type](connection, packet);
460 }
461
462 /**
463 * Handles incoming messages from the server.
464 */
s_process_read_message(struct aws_channel_handler * handler,struct aws_channel_slot * slot,struct aws_io_message * message)465 static int s_process_read_message(
466 struct aws_channel_handler *handler,
467 struct aws_channel_slot *slot,
468 struct aws_io_message *message) {
469
470 struct aws_mqtt_client_connection *connection = handler->impl;
471
472 if (message->message_type != AWS_IO_MESSAGE_APPLICATION_DATA || message->message_data.len < 1) {
473 return AWS_OP_ERR;
474 }
475
476 AWS_LOGF_TRACE(
477 AWS_LS_MQTT_CLIENT,
478 "id=%p: precessing read message of size %zu",
479 (void *)connection,
480 message->message_data.len);
481
482 /* This cursor will be updated as we read through the message. */
483 struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data);
484
485 /* If there's pending packet left over from last time, attempt to complete it. */
486 if (connection->thread_data.pending_packet.len) {
487 int result = AWS_OP_SUCCESS;
488
489 /* This determines how much to read from the message (min(expected_remaining, message.len)) */
490 size_t to_read = connection->thread_data.pending_packet.capacity - connection->thread_data.pending_packet.len;
491 /* This will be set to false if this message still won't complete the packet object. */
492 bool packet_complete = true;
493 if (to_read > message_cursor.len) {
494 to_read = message_cursor.len;
495 packet_complete = false;
496 }
497
498 /* Write the chunk to the buffer.
499 * This will either complete the packet, or be the entirety of message if more data is required. */
500 struct aws_byte_cursor chunk = aws_byte_cursor_advance(&message_cursor, to_read);
501 AWS_ASSERT(chunk.ptr); /* Guaranteed to be in bounds */
502 result = (int)aws_byte_buf_write_from_whole_cursor(&connection->thread_data.pending_packet, chunk) - 1;
503 if (result) {
504 goto handle_error;
505 }
506
507 /* If the packet is still incomplete, don't do anything with the data. */
508 if (!packet_complete) {
509 AWS_LOGF_TRACE(
510 AWS_LS_MQTT_CLIENT,
511 "id=%p: partial message is still incomplete, waiting on another read.",
512 (void *)connection);
513
514 goto cleanup;
515 }
516
517 /* Handle the completed pending packet */
518 struct aws_byte_cursor packet_data = aws_byte_cursor_from_buf(&connection->thread_data.pending_packet);
519 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: full mqtt packet re-assembled, dispatching.", (void *)connection);
520 result = s_process_mqtt_packet(connection, aws_mqtt_get_packet_type(packet_data.ptr), packet_data);
521
522 handle_error:
523 /* Clean up the pending packet */
524 aws_byte_buf_clean_up(&connection->thread_data.pending_packet);
525 AWS_ZERO_STRUCT(connection->thread_data.pending_packet);
526
527 if (result) {
528 return AWS_OP_ERR;
529 }
530 }
531
532 while (message_cursor.len) {
533
534 /* Temp byte cursor so we can decode the header without advancing message_cursor. */
535 struct aws_byte_cursor header_decode = message_cursor;
536
537 struct aws_mqtt_fixed_header packet_header;
538 AWS_ZERO_STRUCT(packet_header);
539 int result = aws_mqtt_fixed_header_decode(&header_decode, &packet_header);
540
541 /* Calculate how much data was read. */
542 const size_t fixed_header_size = message_cursor.len - header_decode.len;
543
544 if (result) {
545 if (aws_last_error() == AWS_ERROR_SHORT_BUFFER) {
546 /* Message data too short, store data and come back later. */
547 AWS_LOGF_TRACE(
548 AWS_LS_MQTT_CLIENT, "id=%p: message is incomplete, waiting on another read.", (void *)connection);
549 if (aws_byte_buf_init(
550 &connection->thread_data.pending_packet,
551 connection->allocator,
552 fixed_header_size + packet_header.remaining_length)) {
553
554 return AWS_OP_ERR;
555 }
556
557 /* Write the partial packet. */
558 if (!aws_byte_buf_write_from_whole_cursor(&connection->thread_data.pending_packet, message_cursor)) {
559 aws_byte_buf_clean_up(&connection->thread_data.pending_packet);
560 return AWS_OP_ERR;
561 }
562
563 aws_reset_error();
564 goto cleanup;
565 } else {
566 return AWS_OP_ERR;
567 }
568 }
569
570 struct aws_byte_cursor packet_data =
571 aws_byte_cursor_advance(&message_cursor, fixed_header_size + packet_header.remaining_length);
572 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: full mqtt packet read, dispatching.", (void *)connection);
573 s_process_mqtt_packet(connection, packet_header.packet_type, packet_data);
574 }
575
576 cleanup:
577 /* Do cleanup */
578 aws_channel_slot_increment_read_window(slot, message->message_data.len);
579 aws_mem_release(message->allocator, message);
580
581 return AWS_OP_SUCCESS;
582 }
583
s_shutdown(struct aws_channel_handler * handler,struct aws_channel_slot * slot,enum aws_channel_direction dir,int error_code,bool free_scarce_resources_immediately)584 static int s_shutdown(
585 struct aws_channel_handler *handler,
586 struct aws_channel_slot *slot,
587 enum aws_channel_direction dir,
588 int error_code,
589 bool free_scarce_resources_immediately) {
590
591 struct aws_mqtt_client_connection *connection = handler->impl;
592
593 if (dir == AWS_CHANNEL_DIR_WRITE) {
594 /* On closing write direction, send out disconnect packet before closing connection. */
595
596 if (!free_scarce_resources_immediately) {
597
598 if (error_code == AWS_OP_SUCCESS) {
599 AWS_LOGF_INFO(
600 AWS_LS_MQTT_CLIENT,
601 "id=%p: sending disconnect message as part of graceful shutdown.",
602 (void *)connection);
603 /* On clean shutdown, send the disconnect message */
604 struct aws_mqtt_packet_connection disconnect;
605 aws_mqtt_packet_disconnect_init(&disconnect);
606
607 struct aws_io_message *message = mqtt_get_message_for_packet(connection, &disconnect.fixed_header);
608 if (!message) {
609 goto done;
610 }
611
612 if (aws_mqtt_packet_connection_encode(&message->message_data, &disconnect)) {
613 AWS_LOGF_DEBUG(
614 AWS_LS_MQTT_CLIENT,
615 "id=%p: failed to encode courteous disconnect io message",
616 (void *)connection);
617 aws_mem_release(message->allocator, message);
618 goto done;
619 }
620
621 if (aws_channel_slot_send_message(slot, message, AWS_CHANNEL_DIR_WRITE)) {
622 AWS_LOGF_DEBUG(
623 AWS_LS_MQTT_CLIENT,
624 "id=%p: failed to send courteous disconnect io message",
625 (void *)connection);
626 aws_mem_release(message->allocator, message);
627 goto done;
628 }
629 }
630 }
631 }
632
633 done:
634 return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, free_scarce_resources_immediately);
635 }
636
/**
 * Initial read window of the MQTT handler: SIZE_MAX, i.e. effectively unbounded.
 */
static size_t s_initial_window_size(struct aws_channel_handler *handler) {
    (void)handler;

    const size_t unbounded_window = SIZE_MAX;
    return unbounded_window;
}
643
/**
 * Handler destructor. Intentionally empty: nothing is released here.
 */
static void s_destroy(struct aws_channel_handler *handler) {
    (void)handler;
}
649
/**
 * Per-message overhead added by this handler: none.
 */
static size_t s_message_overhead(struct aws_channel_handler *handler) {
    (void)handler;

    const size_t no_overhead = 0;
    return no_overhead;
}
654
aws_mqtt_get_client_channel_vtable(void)655 struct aws_channel_handler_vtable *aws_mqtt_get_client_channel_vtable(void) {
656
657 static struct aws_channel_handler_vtable s_vtable = {
658 .process_read_message = &s_process_read_message,
659 .process_write_message = NULL,
660 .increment_read_window = NULL,
661 .shutdown = &s_shutdown,
662 .initial_window_size = &s_initial_window_size,
663 .message_overhead = &s_message_overhead,
664 .destroy = &s_destroy,
665 };
666
667 return &s_vtable;
668 }
669
670 /*******************************************************************************
671 * Helpers
672 ******************************************************************************/
673
mqtt_get_message_for_packet(struct aws_mqtt_client_connection * connection,struct aws_mqtt_fixed_header * header)674 struct aws_io_message *mqtt_get_message_for_packet(
675 struct aws_mqtt_client_connection *connection,
676 struct aws_mqtt_fixed_header *header) {
677
678 const size_t required_length = 3 + header->remaining_length;
679
680 struct aws_io_message *message = aws_channel_acquire_message_from_pool(
681 connection->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, required_length);
682
683 AWS_LOGF_TRACE(
684 AWS_LS_MQTT_CLIENT,
685 "id=%p: Acquiring memory from pool of required_length %zu",
686 (void *)connection,
687 required_length);
688
689 return message;
690 }
691
692 /*******************************************************************************
693 * Requests
694 ******************************************************************************/
695
696 /* Send the request */
s_request_outgoing_task(struct aws_channel_task * task,void * arg,enum aws_task_status status)697 static void s_request_outgoing_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) {
698
699 struct aws_mqtt_request *request = arg;
700 struct aws_mqtt_client_connection *connection = request->connection;
701
702 if (status == AWS_TASK_STATUS_CANCELED) {
703 /* Connection lost before the request ever get send, check the request needs to be retried or not */
704 if (request->retryable) {
705 AWS_LOGF_DEBUG(
706 AWS_LS_MQTT_CLIENT,
707 "static: task id %p, was canceled due to the channel shutting down. Request for packet id "
708 "%" PRIu16 ". will be retried",
709 (void *)task,
710 request->packet_id);
711 /* put it into the offline queue. */
712 { /* BEGIN CRITICAL SECTION */
713 mqtt_connection_lock_synced_data(connection);
714 aws_linked_list_push_back(&connection->synced_data.pending_requests_list, &request->list_node);
715 mqtt_connection_unlock_synced_data(connection);
716 } /* END CRITICAL SECTION */
717 } else {
718 AWS_LOGF_DEBUG(
719 AWS_LS_MQTT_CLIENT,
720 "static: task id %p, was canceled due to the channel shutting down. Request for packet id "
721 "%" PRIu16 ". will NOT be retried, will be cancelled",
722 (void *)task,
723 request->packet_id);
724 /* Fire the callback and clean up the memory, as the connection get destroyed. */
725 if (request->on_complete) {
726 request->on_complete(
727 connection, request->packet_id, AWS_ERROR_MQTT_NOT_CONNECTED, request->on_complete_ud);
728 }
729 { /* BEGIN CRITICAL SECTION */
730 mqtt_connection_lock_synced_data(connection);
731 aws_hash_table_remove(
732 &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
733 aws_memory_pool_release(&connection->synced_data.requests_pool, request);
734 mqtt_connection_unlock_synced_data(connection);
735 } /* END CRITICAL SECTION */
736 }
737 return;
738 }
739
740 /* Send the request */
741 enum aws_mqtt_client_request_state state =
742 request->send_request(request->packet_id, !request->initiated, request->send_request_ud);
743 request->initiated = true;
744 int error_code = AWS_ERROR_SUCCESS;
745 switch (state) {
746 case AWS_MQTT_CLIENT_REQUEST_ERROR:
747 error_code = aws_last_error();
748 AWS_LOGF_ERROR(
749 AWS_LS_MQTT_CLIENT,
750 "id=%p: sending request %" PRIu16 " failed with error %d.",
751 (void *)request->connection,
752 request->packet_id,
753 error_code);
754 /* fall-thru */
755
756 case AWS_MQTT_CLIENT_REQUEST_COMPLETE:
757 AWS_LOGF_TRACE(
758 AWS_LS_MQTT_CLIENT,
759 "id=%p: sending request %" PRIu16 " complete, invoking on_complete callback.",
760 (void *)request->connection,
761 request->packet_id);
762 /* If the send_request function reports the request is complete,
763 * remove from the hash table and call the callback. */
764 if (request->on_complete) {
765 request->on_complete(connection, request->packet_id, error_code, request->on_complete_ud);
766 }
767 { /* BEGIN CRITICAL SECTION */
768 mqtt_connection_lock_synced_data(connection);
769 aws_hash_table_remove(
770 &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
771 aws_memory_pool_release(&connection->synced_data.requests_pool, request);
772 mqtt_connection_unlock_synced_data(connection);
773 } /* END CRITICAL SECTION */
774 break;
775
776 case AWS_MQTT_CLIENT_REQUEST_ONGOING:
777 AWS_LOGF_TRACE(
778 AWS_LS_MQTT_CLIENT,
779 "id=%p: request %" PRIu16 " sent, but waiting on an acknowledgement from peer.",
780 (void *)request->connection,
781 request->packet_id);
782 /* Put the request into the ongoing list */
783 aws_linked_list_push_back(&connection->thread_data.ongoing_requests_list, &request->list_node);
784 break;
785 }
786 }
787
/**
 * Creates a new outbound request on the connection and assigns it a free packet ID.
 *
 * When the connection is CONNECTED, the request's outgoing task is scheduled on the channel right away. Otherwise
 * the request is appended to synced_data.pending_requests_list.
 *
 * @param connection       connection that owns the request (must not be NULL).
 * @param send_request     callback used by the outgoing task to write the request (must not be NULL).
 * @param send_request_ud  user data forwarded to send_request.
 * @param on_complete      optional completion callback, stored on the request.
 * @param on_complete_ud   user data forwarded to on_complete.
 * @param noRetry          when true, the call fails unless currently CONNECTED (no offline queueing), and the
 *                         request is marked non-retryable.
 * @return the packet ID assigned to the request, or 0 on failure. AWS_ERROR_MQTT_CONNECTION_DISCONNECTING,
 *         AWS_ERROR_MQTT_NOT_CONNECTED and AWS_ERROR_MQTT_QUEUE_FULL are raised explicitly. For pool/table
 *         failures, whatever error those calls raised is left in place.
 */
uint16_t mqtt_create_request(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_send_request_fn *send_request,
    void *send_request_ud,
    aws_mqtt_op_complete_fn *on_complete,
    void *on_complete_ud,
    bool noRetry) {

    AWS_ASSERT(connection);
    AWS_ASSERT(send_request);
    struct aws_mqtt_request *next_request = NULL;
    /* Copy of the assigned ID, captured while the lock is held. Once the lock is released (and certainly once the
     * outgoing task is scheduled), the request may be sent, completed and returned to the pool on the event-loop
     * thread. next_request must therefore not be dereferenced afterwards just to read its ID. */
    uint16_t next_request_packet_id = 0;
    bool should_schedule_task = false;
    struct aws_channel *channel = NULL;
    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
            mqtt_connection_unlock_synced_data(connection);
            /* User requested disconnecting, ensure no new requests are made until the channel finished shutting
             * down. */
            AWS_LOGF_ERROR(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Disconnect requested, stop creating any new request until disconnect process finishes.",
                (void *)connection);
            aws_raise_error(AWS_ERROR_MQTT_CONNECTION_DISCONNECTING);
            return 0;
        }

        if (noRetry && connection->synced_data.state != AWS_MQTT_CLIENT_STATE_CONNECTED) {
            mqtt_connection_unlock_synced_data(connection);
            /* Not offline queueing QoS 0 publish or PINGREQ. Fail the call. */
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Not currently connected. No offline queueing for QoS 0 publish or pingreq.",
                (void *)connection);
            aws_raise_error(AWS_ERROR_MQTT_NOT_CONNECTED);
            return 0;
        }
        /**
         * Find a free packet ID.
         * QoS 0 PUBLISH packets don't actually need an ID on the wire,
         * but we assign them internally anyway just so everything has a unique ID.
         *
         * Yes, this is an O(N) search.
         * We remember the last ID we assigned, so it's O(1) in the common case.
         * But it's theoretically possible to reach O(N) where N is just above 64000
         * if the user is letting a ton of un-ack'd messages queue up
         */
        uint16_t search_start = connection->synced_data.packet_id;
        /* ID 0 is never produced by the increment below, so it could never match search_start. Terminate on the
         * last candidate (UINT16_MAX) instead, so a completely full table cannot make this loop spin forever. */
        if (search_start == 0) {
            search_start = UINT16_MAX;
        }
        struct aws_hash_element *elem = NULL;
        while (true) {
            /* Increment ID, watch out for overflow, ID cannot be 0 */
            if (connection->synced_data.packet_id == UINT16_MAX) {
                connection->synced_data.packet_id = 1;
            } else {
                connection->synced_data.packet_id++;
            }

            /* Is there already an outstanding request using this ID? */
            aws_hash_table_find(
                &connection->synced_data.outstanding_requests_table, &connection->synced_data.packet_id, &elem);

            if (elem == NULL) {
                /* Found a free ID! Break out of loop */
                break;
            } else if (connection->synced_data.packet_id == search_start) {
                /* Every ID is taken */
                mqtt_connection_unlock_synced_data(connection);
                AWS_LOGF_ERROR(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: Queue is full. No more packet IDs are available at this time.",
                    (void *)connection);
                aws_raise_error(AWS_ERROR_MQTT_QUEUE_FULL);
                return 0;
            }
        }

        next_request = aws_memory_pool_acquire(&connection->synced_data.requests_pool);
        if (!next_request) {
            mqtt_connection_unlock_synced_data(connection);
            return 0;
        }
        memset(next_request, 0, sizeof(struct aws_mqtt_request));

        next_request->packet_id = connection->synced_data.packet_id;
        next_request_packet_id = next_request->packet_id;

        /* Store the request by packet_id */
        if (aws_hash_table_put(
                &connection->synced_data.outstanding_requests_table, &next_request->packet_id, next_request, NULL)) {
            /* failed to put the next request into the table */
            aws_memory_pool_release(&connection->synced_data.requests_pool, next_request);
            mqtt_connection_unlock_synced_data(connection);
            return 0;
        }
        next_request->allocator = connection->allocator;
        next_request->connection = connection;
        next_request->initiated = false;
        next_request->retryable = !noRetry;
        next_request->send_request = send_request;
        next_request->send_request_ud = send_request_ud;
        next_request->on_complete = on_complete;
        next_request->on_complete_ud = on_complete_ud;
        aws_channel_task_init(
            &next_request->outgoing_task, s_request_outgoing_task, next_request, "mqtt_outgoing_request_task");
        if (connection->synced_data.state != AWS_MQTT_CLIENT_STATE_CONNECTED) {
            aws_linked_list_push_back(&connection->synced_data.pending_requests_list, &next_request->list_node);
        } else {
            AWS_ASSERT(connection->slot);
            AWS_ASSERT(connection->slot->channel);
            should_schedule_task = true;
            channel = connection->slot->channel;
            /* keep the channel alive until the task is scheduled */
            aws_channel_acquire_hold(channel);
        }
        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */
    if (should_schedule_task) {
        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Currently not in the event-loop thread, scheduling a task to send message id %" PRIu16 ".",
            (void *)connection,
            next_request_packet_id);
        aws_channel_schedule_task_now(channel, &next_request->outgoing_task);
        /* release the refcount we hold with the protection of lock */
        aws_channel_release_hold(channel);
    }

    /* Never read next_request here: it may already have been completed and released on another thread. */
    return next_request_packet_id;
}
916
mqtt_request_complete(struct aws_mqtt_client_connection * connection,int error_code,uint16_t packet_id)917 void mqtt_request_complete(struct aws_mqtt_client_connection *connection, int error_code, uint16_t packet_id) {
918
919 AWS_LOGF_TRACE(
920 AWS_LS_MQTT_CLIENT,
921 "id=%p: message id %" PRIu16 " completed with error code %d, removing from outstanding requests list.",
922 (void *)connection,
923 packet_id,
924 error_code);
925
926 bool found_request = false;
927 aws_mqtt_op_complete_fn *on_complete = NULL;
928 void *on_complete_ud = NULL;
929
930 { /* BEGIN CRITICAL SECTION */
931 mqtt_connection_lock_synced_data(connection);
932 struct aws_hash_element *elem = NULL;
933 aws_hash_table_find(&connection->synced_data.outstanding_requests_table, &packet_id, &elem);
934 if (elem != NULL) {
935 found_request = true;
936
937 struct aws_mqtt_request *request = elem->value;
938 on_complete = request->on_complete;
939 on_complete_ud = request->on_complete_ud;
940
941 /* clean up request resources */
942 aws_hash_table_remove_element(&connection->synced_data.outstanding_requests_table, elem);
943 /* remove the request from the list, which is thread_data.ongoing_requests_list */
944 aws_linked_list_remove(&request->list_node);
945 aws_memory_pool_release(&connection->synced_data.requests_pool, request);
946 }
947 mqtt_connection_unlock_synced_data(connection);
948 } /* END CRITICAL SECTION */
949
950 if (!found_request) {
951 AWS_LOGF_DEBUG(
952 AWS_LS_MQTT_CLIENT,
953 "id=%p: received completion for message id %" PRIu16
954 " but no outstanding request exists. Assuming this is an ack of a resend when the first request has "
955 "already completed.",
956 (void *)connection,
957 packet_id);
958 return;
959 }
960
961 /* Invoke the complete callback. */
962 if (on_complete) {
963 on_complete(connection, packet_id, error_code, on_complete_ud);
964 }
965 }
966
967 struct mqtt_shutdown_task {
968 int error_code;
969 struct aws_channel_task task;
970 };
971
s_mqtt_disconnect_task(struct aws_channel_task * channel_task,void * arg,enum aws_task_status status)972 static void s_mqtt_disconnect_task(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
973
974 (void)status;
975
976 struct mqtt_shutdown_task *task = AWS_CONTAINER_OF(channel_task, struct mqtt_shutdown_task, task);
977 struct aws_mqtt_client_connection *connection = arg;
978
979 AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Doing disconnect", (void *)connection);
980 { /* BEGIN CRITICAL SECTION */
981 mqtt_connection_lock_synced_data(connection);
982 /* If there is an outstanding reconnect task, cancel it */
983 if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING && connection->reconnect_task) {
984 aws_atomic_store_ptr(&connection->reconnect_task->connection_ptr, NULL);
985 /* If the reconnect_task isn't scheduled, free it */
986 if (connection->reconnect_task && !connection->reconnect_task->task.timestamp) {
987 aws_mem_release(connection->reconnect_task->allocator, connection->reconnect_task);
988 }
989 connection->reconnect_task = NULL;
990 }
991 mqtt_connection_unlock_synced_data(connection);
992 } /* END CRITICAL SECTION */
993
994 if (connection->slot && connection->slot->channel) {
995 aws_channel_shutdown(connection->slot->channel, task->error_code);
996 }
997
998 aws_mem_release(connection->allocator, task);
999 }
1000
mqtt_disconnect_impl(struct aws_mqtt_client_connection * connection,int error_code)1001 void mqtt_disconnect_impl(struct aws_mqtt_client_connection *connection, int error_code) {
1002 if (connection->slot) {
1003 struct mqtt_shutdown_task *shutdown_task =
1004 aws_mem_calloc(connection->allocator, 1, sizeof(struct mqtt_shutdown_task));
1005 shutdown_task->error_code = error_code;
1006 aws_channel_task_init(&shutdown_task->task, s_mqtt_disconnect_task, connection, "mqtt_disconnect");
1007 aws_channel_schedule_task_now(connection->slot->channel, &shutdown_task->task);
1008 }
1009 }
1010