1 /**
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  * SPDX-License-Identifier: Apache-2.0.
4  */
5 #include <aws/crt/mqtt/MqttClient.h>
6 
7 #include <aws/crt/StlAllocator.h>
8 #include <aws/crt/http/HttpProxyStrategy.h>
9 #include <aws/crt/http/HttpRequestResponse.h>
10 #include <aws/crt/io/Bootstrap.h>
11 
12 #include <utility>
13 
14 namespace Aws
15 {
16     namespace Crt
17     {
18         namespace Mqtt
19         {
s_onConnectionInterrupted(aws_mqtt_client_connection *,int errorCode,void * userData)20             void MqttConnection::s_onConnectionInterrupted(aws_mqtt_client_connection *, int errorCode, void *userData)
21             {
22                 auto connWrapper = reinterpret_cast<MqttConnection *>(userData);
23                 if (connWrapper->OnConnectionInterrupted)
24                 {
25                     connWrapper->OnConnectionInterrupted(*connWrapper, errorCode);
26                 }
27             }
28 
s_onConnectionResumed(aws_mqtt_client_connection *,ReturnCode returnCode,bool sessionPresent,void * userData)29             void MqttConnection::s_onConnectionResumed(
30                 aws_mqtt_client_connection *,
31                 ReturnCode returnCode,
32                 bool sessionPresent,
33                 void *userData)
34             {
35                 auto connWrapper = reinterpret_cast<MqttConnection *>(userData);
36                 if (connWrapper->OnConnectionResumed)
37                 {
38                     connWrapper->OnConnectionResumed(*connWrapper, returnCode, sessionPresent);
39                 }
40             }
41 
s_onConnectionCompleted(aws_mqtt_client_connection *,int errorCode,enum aws_mqtt_connect_return_code returnCode,bool sessionPresent,void * userData)42             void MqttConnection::s_onConnectionCompleted(
43                 aws_mqtt_client_connection *,
44                 int errorCode,
45                 enum aws_mqtt_connect_return_code returnCode,
46                 bool sessionPresent,
47                 void *userData)
48             {
49                 auto connWrapper = reinterpret_cast<MqttConnection *>(userData);
50                 if (connWrapper->OnConnectionCompleted)
51                 {
52                     connWrapper->OnConnectionCompleted(*connWrapper, errorCode, returnCode, sessionPresent);
53                 }
54             }
55 
s_onDisconnect(aws_mqtt_client_connection *,void * userData)56             void MqttConnection::s_onDisconnect(aws_mqtt_client_connection *, void *userData)
57             {
58                 auto connWrapper = reinterpret_cast<MqttConnection *>(userData);
59                 if (connWrapper->OnDisconnect)
60                 {
61                     connWrapper->OnDisconnect(*connWrapper);
62                 }
63             }
64 
65             struct PubCallbackData
66             {
PubCallbackDataAws::Crt::Mqtt::PubCallbackData67                 PubCallbackData() : connection(nullptr), allocator(nullptr) {}
68 
69                 MqttConnection *connection;
70                 OnMessageReceivedHandler onMessageReceived;
71                 Allocator *allocator;
72             };
73 
s_cleanUpOnPublishData(void * userData)74             static void s_cleanUpOnPublishData(void *userData)
75             {
76                 auto callbackData = reinterpret_cast<PubCallbackData *>(userData);
77                 Crt::Delete(callbackData, callbackData->allocator);
78             }
79 
s_onPublish(aws_mqtt_client_connection *,const aws_byte_cursor * topic,const aws_byte_cursor * payload,bool dup,enum aws_mqtt_qos qos,bool retain,void * userData)80             void MqttConnection::s_onPublish(
81                 aws_mqtt_client_connection *,
82                 const aws_byte_cursor *topic,
83                 const aws_byte_cursor *payload,
84                 bool dup,
85                 enum aws_mqtt_qos qos,
86                 bool retain,
87                 void *userData)
88             {
89                 auto callbackData = reinterpret_cast<PubCallbackData *>(userData);
90 
91                 if (callbackData->onMessageReceived)
92                 {
93                     String topicStr(reinterpret_cast<char *>(topic->ptr), topic->len);
94                     ByteBuf payloadBuf = aws_byte_buf_from_array(payload->ptr, payload->len);
95                     callbackData->onMessageReceived(
96                         *(callbackData->connection), topicStr, payloadBuf, dup, qos, retain);
97                 }
98             }
99 
100             struct OpCompleteCallbackData
101             {
OpCompleteCallbackDataAws::Crt::Mqtt::OpCompleteCallbackData102                 OpCompleteCallbackData() : connection(nullptr), topic(nullptr), allocator(nullptr) {}
103 
104                 MqttConnection *connection;
105                 OnOperationCompleteHandler onOperationComplete;
106                 const char *topic;
107                 Allocator *allocator;
108             };
109 
s_onOpComplete(aws_mqtt_client_connection *,uint16_t packetId,int errorCode,void * userData)110             void MqttConnection::s_onOpComplete(
111                 aws_mqtt_client_connection *,
112                 uint16_t packetId,
113                 int errorCode,
114                 void *userData)
115             {
116                 auto callbackData = reinterpret_cast<OpCompleteCallbackData *>(userData);
117 
118                 if (callbackData->onOperationComplete)
119                 {
120                     callbackData->onOperationComplete(*callbackData->connection, packetId, errorCode);
121                 }
122 
123                 if (callbackData->topic)
124                 {
125                     aws_mem_release(
126                         callbackData->allocator, reinterpret_cast<void *>(const_cast<char *>(callbackData->topic)));
127                 }
128 
129                 Crt::Delete(callbackData, callbackData->allocator);
130             }
131 
132             struct SubAckCallbackData
133             {
SubAckCallbackDataAws::Crt::Mqtt::SubAckCallbackData134                 SubAckCallbackData() : connection(nullptr), topic(nullptr), allocator(nullptr) {}
135 
136                 MqttConnection *connection;
137                 OnSubAckHandler onSubAck;
138                 const char *topic;
139                 Allocator *allocator;
140             };
141 
s_onSubAck(aws_mqtt_client_connection *,uint16_t packetId,const struct aws_byte_cursor * topic,enum aws_mqtt_qos qos,int errorCode,void * userData)142             void MqttConnection::s_onSubAck(
143                 aws_mqtt_client_connection *,
144                 uint16_t packetId,
145                 const struct aws_byte_cursor *topic,
146                 enum aws_mqtt_qos qos,
147                 int errorCode,
148                 void *userData)
149             {
150                 auto callbackData = reinterpret_cast<SubAckCallbackData *>(userData);
151 
152                 if (callbackData->onSubAck)
153                 {
154                     String topicStr(reinterpret_cast<char *>(topic->ptr), topic->len);
155                     callbackData->onSubAck(*callbackData->connection, packetId, topicStr, qos, errorCode);
156                 }
157 
158                 if (callbackData->topic)
159                 {
160                     aws_mem_release(
161                         callbackData->allocator, reinterpret_cast<void *>(const_cast<char *>(callbackData->topic)));
162                 }
163 
164                 Crt::Delete(callbackData, callbackData->allocator);
165             }
166 
167             struct MultiSubAckCallbackData
168             {
MultiSubAckCallbackDataAws::Crt::Mqtt::MultiSubAckCallbackData169                 MultiSubAckCallbackData() : connection(nullptr), topic(nullptr), allocator(nullptr) {}
170 
171                 MqttConnection *connection;
172                 OnMultiSubAckHandler onSubAck;
173                 const char *topic;
174                 Allocator *allocator;
175             };
176 
s_onMultiSubAck(aws_mqtt_client_connection *,uint16_t packetId,const struct aws_array_list * topicSubacks,int errorCode,void * userData)177             void MqttConnection::s_onMultiSubAck(
178                 aws_mqtt_client_connection *,
179                 uint16_t packetId,
180                 const struct aws_array_list *topicSubacks,
181                 int errorCode,
182                 void *userData)
183             {
184                 auto callbackData = reinterpret_cast<MultiSubAckCallbackData *>(userData);
185 
186                 if (callbackData->onSubAck)
187                 {
188                     size_t length = aws_array_list_length(topicSubacks);
189                     Vector<String> topics;
190                     topics.reserve(length);
191                     QOS qos = AWS_MQTT_QOS_AT_MOST_ONCE;
192                     for (size_t i = 0; i < length; ++i)
193                     {
194                         aws_mqtt_topic_subscription *subscription = NULL;
195                         aws_array_list_get_at(topicSubacks, &subscription, i);
196                         topics.push_back(
197                             String(reinterpret_cast<char *>(subscription->topic.ptr), subscription->topic.len));
198                         qos = subscription->qos;
199                     }
200 
201                     callbackData->onSubAck(*callbackData->connection, packetId, topics, qos, errorCode);
202                 }
203 
204                 if (callbackData->topic)
205                 {
206                     aws_mem_release(
207                         callbackData->allocator, reinterpret_cast<void *>(const_cast<char *>(callbackData->topic)));
208                 }
209 
210                 Crt::Delete(callbackData, callbackData->allocator);
211             }
212 
s_connectionInit(MqttConnection * self,const char * hostName,uint16_t port,const Io::SocketOptions & socketOptions)213             void MqttConnection::s_connectionInit(
214                 MqttConnection *self,
215                 const char *hostName,
216                 uint16_t port,
217                 const Io::SocketOptions &socketOptions)
218             {
219 
220                 self->m_hostName = String(hostName);
221                 self->m_port = port;
222                 self->m_socketOptions = socketOptions;
223 
224                 self->m_underlyingConnection = aws_mqtt_client_connection_new(self->m_owningClient);
225 
226                 if (self->m_underlyingConnection)
227                 {
228                     aws_mqtt_client_connection_set_connection_interruption_handlers(
229                         self->m_underlyingConnection,
230                         MqttConnection::s_onConnectionInterrupted,
231                         self,
232                         MqttConnection::s_onConnectionResumed,
233                         self);
234                 }
235             }
236 
s_onWebsocketHandshake(struct aws_http_message * rawRequest,void * user_data,aws_mqtt_transform_websocket_handshake_complete_fn * complete_fn,void * complete_ctx)237             void MqttConnection::s_onWebsocketHandshake(
238                 struct aws_http_message *rawRequest,
239                 void *user_data,
240                 aws_mqtt_transform_websocket_handshake_complete_fn *complete_fn,
241                 void *complete_ctx)
242             {
243                 auto connection = reinterpret_cast<MqttConnection *>(user_data);
244 
245                 Allocator *allocator = connection->m_owningClient->allocator;
246                 // we have to do this because of private constructors.
247                 auto toSeat =
248                     reinterpret_cast<Http::HttpRequest *>(aws_mem_acquire(allocator, sizeof(Http::HttpRequest)));
249                 toSeat = new (toSeat) Http::HttpRequest(allocator, rawRequest);
250 
251                 std::shared_ptr<Http::HttpRequest> request = std::shared_ptr<Http::HttpRequest>(
252                     toSeat, [allocator](Http::HttpRequest *ptr) { Crt::Delete(ptr, allocator); });
253 
254                 auto onInterceptComplete =
255                     [complete_fn,
256                      complete_ctx](const std::shared_ptr<Http::HttpRequest> &transformedRequest, int errorCode) {
257                         complete_fn(transformedRequest->GetUnderlyingMessage(), errorCode, complete_ctx);
258                     };
259 
260                 connection->WebsocketInterceptor(request, onInterceptComplete);
261             }
262 
MqttConnection(aws_mqtt_client * client,const char * hostName,uint16_t port,const Io::SocketOptions & socketOptions,const Crt::Io::TlsContext & tlsContext,bool useWebsocket)263             MqttConnection::MqttConnection(
264                 aws_mqtt_client *client,
265                 const char *hostName,
266                 uint16_t port,
267                 const Io::SocketOptions &socketOptions,
268                 const Crt::Io::TlsContext &tlsContext,
269                 bool useWebsocket) noexcept
270                 : m_owningClient(client), m_tlsContext(tlsContext), m_tlsOptions(tlsContext.NewConnectionOptions()),
271                   m_onAnyCbData(nullptr), m_useTls(true), m_useWebsocket(useWebsocket)
272             {
273                 s_connectionInit(this, hostName, port, socketOptions);
274             }
275 
MqttConnection(aws_mqtt_client * client,const char * hostName,uint16_t port,const Io::SocketOptions & socketOptions,bool useWebsocket)276             MqttConnection::MqttConnection(
277                 aws_mqtt_client *client,
278                 const char *hostName,
279                 uint16_t port,
280                 const Io::SocketOptions &socketOptions,
281                 bool useWebsocket) noexcept
282                 : m_owningClient(client), m_onAnyCbData(nullptr), m_useTls(false), m_useWebsocket(useWebsocket)
283             {
284                 s_connectionInit(this, hostName, port, socketOptions);
285             }
286 
~MqttConnection()287             MqttConnection::~MqttConnection()
288             {
289                 if (*this)
290                 {
291                     aws_mqtt_client_connection_release(m_underlyingConnection);
292 
293                     if (m_onAnyCbData)
294                     {
295                         auto pubCallbackData = reinterpret_cast<PubCallbackData *>(m_onAnyCbData);
296                         Crt::Delete(pubCallbackData, pubCallbackData->allocator);
297                     }
298                 }
299             }
300 
operator bool() const301             MqttConnection::operator bool() const noexcept { return m_underlyingConnection != nullptr; }
302 
LastError() const303             int MqttConnection::LastError() const noexcept { return aws_last_error(); }
304 
SetWill(const char * topic,QOS qos,bool retain,const ByteBuf & payload)305             bool MqttConnection::SetWill(const char *topic, QOS qos, bool retain, const ByteBuf &payload) noexcept
306             {
307                 ByteBuf topicBuf = aws_byte_buf_from_c_str(topic);
308                 ByteCursor topicCur = aws_byte_cursor_from_buf(&topicBuf);
309                 ByteCursor payloadCur = aws_byte_cursor_from_buf(&payload);
310 
311                 return aws_mqtt_client_connection_set_will(
312                            m_underlyingConnection, &topicCur, qos, retain, &payloadCur) == 0;
313             }
314 
SetLogin(const char * userName,const char * password)315             bool MqttConnection::SetLogin(const char *userName, const char *password) noexcept
316             {
317                 ByteBuf userNameBuf = aws_byte_buf_from_c_str(userName);
318                 ByteCursor userNameCur = aws_byte_cursor_from_buf(&userNameBuf);
319 
320                 ByteCursor *pwdCurPtr = nullptr;
321                 ByteCursor pwdCur;
322 
323                 if (password)
324                 {
325                     pwdCur = ByteCursorFromCString(password);
326                     pwdCurPtr = &pwdCur;
327                 }
328                 return aws_mqtt_client_connection_set_login(m_underlyingConnection, &userNameCur, pwdCurPtr) == 0;
329             }
330 
SetWebsocketProxyOptions(const Http::HttpClientConnectionProxyOptions & proxyOptions)331             bool MqttConnection::SetWebsocketProxyOptions(
332                 const Http::HttpClientConnectionProxyOptions &proxyOptions) noexcept
333             {
334                 m_proxyOptions = proxyOptions;
335                 return true;
336             }
337 
SetHttpProxyOptions(const Http::HttpClientConnectionProxyOptions & proxyOptions)338             bool MqttConnection::SetHttpProxyOptions(
339                 const Http::HttpClientConnectionProxyOptions &proxyOptions) noexcept
340             {
341                 m_proxyOptions = proxyOptions;
342                 return true;
343             }
344 
SetReconnectTimeout(uint64_t min_seconds,uint64_t max_seconds)345             bool MqttConnection::SetReconnectTimeout(uint64_t min_seconds, uint64_t max_seconds) noexcept
346             {
347                 return aws_mqtt_client_connection_set_reconnect_timeout(
348                            m_underlyingConnection, min_seconds, max_seconds) == 0;
349             }
350 
Connect(const char * clientId,bool cleanSession,uint16_t keepAliveTime,uint32_t pingTimeoutMs,uint32_t protocolOperationTimeoutMs)351             bool MqttConnection::Connect(
352                 const char *clientId,
353                 bool cleanSession,
354                 uint16_t keepAliveTime,
355                 uint32_t pingTimeoutMs,
356                 uint32_t protocolOperationTimeoutMs) noexcept
357             {
358                 aws_mqtt_connection_options options;
359                 AWS_ZERO_STRUCT(options);
360                 options.client_id = aws_byte_cursor_from_c_str(clientId);
361                 options.host_name = aws_byte_cursor_from_array(
362                     reinterpret_cast<const uint8_t *>(m_hostName.data()), m_hostName.length());
363                 options.tls_options =
364                     m_useTls ? const_cast<aws_tls_connection_options *>(m_tlsOptions.GetUnderlyingHandle()) : nullptr;
365                 options.port = m_port;
366                 options.socket_options = &m_socketOptions.GetImpl();
367                 options.clean_session = cleanSession;
368                 options.keep_alive_time_secs = keepAliveTime;
369                 options.ping_timeout_ms = pingTimeoutMs;
370                 options.protocol_operation_timeout_ms = protocolOperationTimeoutMs;
371                 options.on_connection_complete = MqttConnection::s_onConnectionCompleted;
372                 options.user_data = this;
373 
374                 if (m_useWebsocket)
375                 {
376                     if (WebsocketInterceptor)
377                     {
378                         if (aws_mqtt_client_connection_use_websockets(
379                                 m_underlyingConnection, MqttConnection::s_onWebsocketHandshake, this, nullptr, nullptr))
380                         {
381                             return false;
382                         }
383                     }
384                     else
385                     {
386                         if (aws_mqtt_client_connection_use_websockets(
387                                 m_underlyingConnection, nullptr, nullptr, nullptr, nullptr))
388                         {
389                             return false;
390                         }
391                     }
392                 }
393 
394                 if (m_proxyOptions)
395                 {
396                     struct aws_http_proxy_options proxyOptions;
397                     m_proxyOptions->InitializeRawProxyOptions(proxyOptions);
398 
399                     if (aws_mqtt_client_connection_set_http_proxy_options(m_underlyingConnection, &proxyOptions))
400                     {
401                         return false;
402                     }
403                 }
404 
405                 return aws_mqtt_client_connection_connect(m_underlyingConnection, &options) == AWS_OP_SUCCESS;
406             }
407 
Disconnect()408             bool MqttConnection::Disconnect() noexcept
409             {
410                 return aws_mqtt_client_connection_disconnect(
411                            m_underlyingConnection, MqttConnection::s_onDisconnect, this) == AWS_OP_SUCCESS;
412             }
413 
GetUnderlyingConnection()414             aws_mqtt_client_connection *MqttConnection::GetUnderlyingConnection() noexcept
415             {
416                 return m_underlyingConnection;
417             }
418 
SetOnMessageHandler(OnPublishReceivedHandler && onPublish)419             bool MqttConnection::SetOnMessageHandler(OnPublishReceivedHandler &&onPublish) noexcept
420             {
421                 return SetOnMessageHandler(
422                     [onPublish](
423                         MqttConnection &connection, const String &topic, const ByteBuf &payload, bool, QOS, bool) {
424                         onPublish(connection, topic, payload);
425                     });
426             }
427 
SetOnMessageHandler(OnMessageReceivedHandler && onMessage)428             bool MqttConnection::SetOnMessageHandler(OnMessageReceivedHandler &&onMessage) noexcept
429             {
430                 auto pubCallbackData = Aws::Crt::New<PubCallbackData>(m_owningClient->allocator);
431 
432                 if (!pubCallbackData)
433                 {
434                     return false;
435                 }
436 
437                 pubCallbackData->connection = this;
438                 pubCallbackData->onMessageReceived = std::move(onMessage);
439                 pubCallbackData->allocator = m_owningClient->allocator;
440 
441                 if (!aws_mqtt_client_connection_set_on_any_publish_handler(
442                         m_underlyingConnection, s_onPublish, pubCallbackData))
443                 {
444                     m_onAnyCbData = reinterpret_cast<void *>(pubCallbackData);
445                     return true;
446                 }
447 
448                 Aws::Crt::Delete(pubCallbackData, pubCallbackData->allocator);
449                 return false;
450             }
451 
Subscribe(const char * topicFilter,QOS qos,OnPublishReceivedHandler && onPublish,OnSubAckHandler && onSubAck)452             uint16_t MqttConnection::Subscribe(
453                 const char *topicFilter,
454                 QOS qos,
455                 OnPublishReceivedHandler &&onPublish,
456                 OnSubAckHandler &&onSubAck) noexcept
457             {
458                 return Subscribe(
459                     topicFilter,
460                     qos,
461                     [onPublish](
462                         MqttConnection &connection, const String &topic, const ByteBuf &payload, bool, QOS, bool) {
463                         onPublish(connection, topic, payload);
464                     },
465                     std::move(onSubAck));
466             }
467 
Subscribe(const char * topicFilter,QOS qos,OnMessageReceivedHandler && onMessage,OnSubAckHandler && onSubAck)468             uint16_t MqttConnection::Subscribe(
469                 const char *topicFilter,
470                 QOS qos,
471                 OnMessageReceivedHandler &&onMessage,
472                 OnSubAckHandler &&onSubAck) noexcept
473             {
474                 auto pubCallbackData = Crt::New<PubCallbackData>(m_owningClient->allocator);
475 
476                 if (!pubCallbackData)
477                 {
478                     return 0;
479                 }
480 
481                 pubCallbackData->connection = this;
482                 pubCallbackData->onMessageReceived = std::move(onMessage);
483                 pubCallbackData->allocator = m_owningClient->allocator;
484 
485                 auto subAckCallbackData = Crt::New<SubAckCallbackData>(m_owningClient->allocator);
486 
487                 if (!subAckCallbackData)
488                 {
489                     Crt::Delete(pubCallbackData, m_owningClient->allocator);
490                     return 0;
491                 }
492 
493                 subAckCallbackData->connection = this;
494                 subAckCallbackData->allocator = m_owningClient->allocator;
495                 subAckCallbackData->onSubAck = std::move(onSubAck);
496                 subAckCallbackData->topic = nullptr;
497                 subAckCallbackData->allocator = m_owningClient->allocator;
498 
499                 ByteBuf topicFilterBuf = aws_byte_buf_from_c_str(topicFilter);
500                 ByteCursor topicFilterCur = aws_byte_cursor_from_buf(&topicFilterBuf);
501 
502                 uint16_t packetId = aws_mqtt_client_connection_subscribe(
503                     m_underlyingConnection,
504                     &topicFilterCur,
505                     qos,
506                     s_onPublish,
507                     pubCallbackData,
508                     s_cleanUpOnPublishData,
509                     s_onSubAck,
510                     subAckCallbackData);
511 
512                 if (!packetId)
513                 {
514                     Crt::Delete(pubCallbackData, pubCallbackData->allocator);
515                     Crt::Delete(subAckCallbackData, subAckCallbackData->allocator);
516                 }
517 
518                 return packetId;
519             }
520 
Subscribe(const Vector<std::pair<const char *,OnPublishReceivedHandler>> & topicFilters,QOS qos,OnMultiSubAckHandler && onSubAck)521             uint16_t MqttConnection::Subscribe(
522                 const Vector<std::pair<const char *, OnPublishReceivedHandler>> &topicFilters,
523                 QOS qos,
524                 OnMultiSubAckHandler &&onSubAck) noexcept
525             {
526                 Vector<std::pair<const char *, OnMessageReceivedHandler>> newTopicFilters;
527                 newTopicFilters.reserve(topicFilters.size());
528                 for (const auto &pair : topicFilters)
529                 {
530                     const OnPublishReceivedHandler &pubHandler = pair.second;
531                     newTopicFilters.emplace_back(
532                         pair.first,
533                         [pubHandler](
534                             MqttConnection &connection, const String &topic, const ByteBuf &payload, bool, QOS, bool) {
535                             pubHandler(connection, topic, payload);
536                         });
537                 }
538                 return Subscribe(newTopicFilters, qos, std::move(onSubAck));
539             }
540 
Subscribe(const Vector<std::pair<const char *,OnMessageReceivedHandler>> & topicFilters,QOS qos,OnMultiSubAckHandler && onSubAck)541             uint16_t MqttConnection::Subscribe(
542                 const Vector<std::pair<const char *, OnMessageReceivedHandler>> &topicFilters,
543                 QOS qos,
544                 OnMultiSubAckHandler &&onSubAck) noexcept
545             {
546                 uint16_t packetId = 0;
547                 auto subAckCallbackData = Crt::New<MultiSubAckCallbackData>(m_owningClient->allocator);
548 
549                 if (!subAckCallbackData)
550                 {
551                     return 0;
552                 }
553 
554                 aws_array_list multiPub;
555                 AWS_ZERO_STRUCT(multiPub);
556 
557                 if (aws_array_list_init_dynamic(
558                         &multiPub, m_owningClient->allocator, topicFilters.size(), sizeof(aws_mqtt_topic_subscription)))
559                 {
560                     Crt::Delete(subAckCallbackData, m_owningClient->allocator);
561                     return 0;
562                 }
563 
564                 for (auto &topicFilter : topicFilters)
565                 {
566                     auto pubCallbackData = Crt::New<PubCallbackData>(m_owningClient->allocator);
567 
568                     if (!pubCallbackData)
569                     {
570                         goto clean_up;
571                     }
572 
573                     pubCallbackData->connection = this;
574                     pubCallbackData->onMessageReceived = topicFilter.second;
575                     pubCallbackData->allocator = m_owningClient->allocator;
576 
577                     ByteBuf topicFilterBuf = aws_byte_buf_from_c_str(topicFilter.first);
578                     ByteCursor topicFilterCur = aws_byte_cursor_from_buf(&topicFilterBuf);
579 
580                     aws_mqtt_topic_subscription subscription;
581                     subscription.on_cleanup = s_cleanUpOnPublishData;
582                     subscription.on_publish = s_onPublish;
583                     subscription.on_publish_ud = pubCallbackData;
584                     subscription.qos = qos;
585                     subscription.topic = topicFilterCur;
586 
587                     aws_array_list_push_back(&multiPub, reinterpret_cast<const void *>(&subscription));
588                 }
589 
590                 subAckCallbackData->connection = this;
591                 subAckCallbackData->allocator = m_owningClient->allocator;
592                 subAckCallbackData->onSubAck = std::move(onSubAck);
593                 subAckCallbackData->topic = nullptr;
594                 subAckCallbackData->allocator = m_owningClient->allocator;
595 
596                 packetId = aws_mqtt_client_connection_subscribe_multiple(
597                     m_underlyingConnection, &multiPub, s_onMultiSubAck, subAckCallbackData);
598 
599             clean_up:
600                 if (!packetId)
601                 {
602                     size_t length = aws_array_list_length(&multiPub);
603                     for (size_t i = 0; i < length; ++i)
604                     {
605                         aws_mqtt_topic_subscription *subscription = NULL;
606                         aws_array_list_get_at_ptr(&multiPub, reinterpret_cast<void **>(&subscription), i);
607                         auto pubCallbackData = reinterpret_cast<PubCallbackData *>(subscription->on_publish_ud);
608                         Crt::Delete(pubCallbackData, m_owningClient->allocator);
609                     }
610 
611                     Crt::Delete(subAckCallbackData, m_owningClient->allocator);
612                 }
613 
614                 aws_array_list_clean_up(&multiPub);
615 
616                 return packetId;
617             }
618 
Unsubscribe(const char * topicFilter,OnOperationCompleteHandler && onOpComplete)619             uint16_t MqttConnection::Unsubscribe(
620                 const char *topicFilter,
621                 OnOperationCompleteHandler &&onOpComplete) noexcept
622             {
623                 auto opCompleteCallbackData = Crt::New<OpCompleteCallbackData>(m_owningClient->allocator);
624 
625                 if (!opCompleteCallbackData)
626                 {
627                     return 0;
628                 }
629 
630                 opCompleteCallbackData->connection = this;
631                 opCompleteCallbackData->allocator = m_owningClient->allocator;
632                 opCompleteCallbackData->onOperationComplete = std::move(onOpComplete);
633                 opCompleteCallbackData->topic = nullptr;
634                 ByteBuf topicFilterBuf = aws_byte_buf_from_c_str(topicFilter);
635                 ByteCursor topicFilterCur = aws_byte_cursor_from_buf(&topicFilterBuf);
636 
637                 uint16_t packetId = aws_mqtt_client_connection_unsubscribe(
638                     m_underlyingConnection, &topicFilterCur, s_onOpComplete, opCompleteCallbackData);
639 
640                 if (!packetId)
641                 {
642                     Crt::Delete(opCompleteCallbackData, m_owningClient->allocator);
643                 }
644 
645                 return packetId;
646             }
647 
Publish(const char * topic,QOS qos,bool retain,const ByteBuf & payload,OnOperationCompleteHandler && onOpComplete)648             uint16_t MqttConnection::Publish(
649                 const char *topic,
650                 QOS qos,
651                 bool retain,
652                 const ByteBuf &payload,
653                 OnOperationCompleteHandler &&onOpComplete) noexcept
654             {
655 
656                 auto opCompleteCallbackData = Crt::New<OpCompleteCallbackData>(m_owningClient->allocator);
657                 if (!opCompleteCallbackData)
658                 {
659                     return 0;
660                 }
661 
662                 size_t topicLen = strlen(topic) + 1;
663                 char *topicCpy =
664                     reinterpret_cast<char *>(aws_mem_calloc(m_owningClient->allocator, topicLen, sizeof(char)));
665 
666                 if (!topicCpy)
667                 {
668                     Crt::Delete(opCompleteCallbackData, m_owningClient->allocator);
669                 }
670 
671                 memcpy(topicCpy, topic, topicLen);
672 
673                 opCompleteCallbackData->connection = this;
674                 opCompleteCallbackData->allocator = m_owningClient->allocator;
675                 opCompleteCallbackData->onOperationComplete = std::move(onOpComplete);
676                 opCompleteCallbackData->topic = topicCpy;
677                 ByteCursor topicCur = aws_byte_cursor_from_array(topicCpy, topicLen - 1);
678 
679                 ByteCursor payloadCur = aws_byte_cursor_from_buf(&payload);
680                 uint16_t packetId = aws_mqtt_client_connection_publish(
681                     m_underlyingConnection,
682                     &topicCur,
683                     qos,
684                     retain,
685                     &payloadCur,
686                     s_onOpComplete,
687                     opCompleteCallbackData);
688 
689                 if (!packetId)
690                 {
691                     aws_mem_release(m_owningClient->allocator, reinterpret_cast<void *>(topicCpy));
692                     Crt::Delete(opCompleteCallbackData, m_owningClient->allocator);
693                 }
694 
695                 return packetId;
696             }
697 
MqttClient(Io::ClientBootstrap & bootstrap,Allocator * allocator)698             MqttClient::MqttClient(Io::ClientBootstrap &bootstrap, Allocator *allocator) noexcept
699                 : m_client(aws_mqtt_client_new(allocator, bootstrap.GetUnderlyingHandle()))
700             {
701             }
702 
~MqttClient()703             MqttClient::~MqttClient()
704             {
705                 aws_mqtt_client_release(m_client);
706                 m_client = nullptr;
707             }
708 
MqttClient(MqttClient && toMove)709             MqttClient::MqttClient(MqttClient &&toMove) noexcept : m_client(toMove.m_client)
710             {
711                 toMove.m_client = nullptr;
712             }
713 
operator =(MqttClient && toMove)714             MqttClient &MqttClient::operator=(MqttClient &&toMove) noexcept
715             {
716                 if (&toMove != this)
717                 {
718                     m_client = toMove.m_client;
719                     toMove.m_client = nullptr;
720                 }
721 
722                 return *this;
723             }
724 
operator bool() const725             MqttClient::operator bool() const noexcept { return m_client != nullptr; }
726 
LastError() const727             int MqttClient::LastError() const noexcept { return aws_last_error(); }
728 
NewConnection(const char * hostName,uint16_t port,const Io::SocketOptions & socketOptions,const Crt::Io::TlsContext & tlsContext,bool useWebsocket)729             std::shared_ptr<MqttConnection> MqttClient::NewConnection(
730                 const char *hostName,
731                 uint16_t port,
732                 const Io::SocketOptions &socketOptions,
733                 const Crt::Io::TlsContext &tlsContext,
734                 bool useWebsocket) noexcept
735             {
736                 if (!tlsContext)
737                 {
738                     AWS_LOGF_ERROR(
739                         AWS_LS_MQTT_CLIENT,
740                         "id=%p Trying to call MqttClient::NewConnection using an invalid TlsContext.",
741                         (void *)m_client);
742                     aws_raise_error(AWS_ERROR_INVALID_ARGUMENT);
743                     return nullptr;
744                 }
745 
746                 // If you're reading this and asking.... why is this so complicated? Why not use make_shared
747                 // or allocate_shared? Well, MqttConnection constructors are private and stl is dumb like that.
748                 // so, we do it manually.
749                 Allocator *allocator = m_client->allocator;
750                 MqttConnection *toSeat =
751                     reinterpret_cast<MqttConnection *>(aws_mem_acquire(allocator, sizeof(MqttConnection)));
752                 if (!toSeat)
753                 {
754                     return nullptr;
755                 }
756 
757                 toSeat = new (toSeat) MqttConnection(m_client, hostName, port, socketOptions, tlsContext, useWebsocket);
758                 return std::shared_ptr<MqttConnection>(toSeat, [allocator](MqttConnection *connection) {
759                     connection->~MqttConnection();
760                     aws_mem_release(allocator, reinterpret_cast<void *>(connection));
761                 });
762             }
763 
NewConnection(const char * hostName,uint16_t port,const Io::SocketOptions & socketOptions,bool useWebsocket)764             std::shared_ptr<MqttConnection> MqttClient::NewConnection(
765                 const char *hostName,
766                 uint16_t port,
767                 const Io::SocketOptions &socketOptions,
768                 bool useWebsocket) noexcept
769 
770             {
771                 // If you're reading this and asking.... why is this so complicated? Why not use make_shared
772                 // or allocate_shared? Well, MqttConnection constructors are private and stl is dumb like that.
773                 // so, we do it manually.
774                 Allocator *allocator = m_client->allocator;
775                 MqttConnection *toSeat =
776                     reinterpret_cast<MqttConnection *>(aws_mem_acquire(m_client->allocator, sizeof(MqttConnection)));
777                 if (!toSeat)
778                 {
779                     return nullptr;
780                 }
781 
782                 toSeat = new (toSeat) MqttConnection(m_client, hostName, port, socketOptions, useWebsocket);
783                 return std::shared_ptr<MqttConnection>(toSeat, [allocator](MqttConnection *connection) {
784                     connection->~MqttConnection();
785                     aws_mem_release(allocator, reinterpret_cast<void *>(connection));
786                 });
787             }
788         } // namespace Mqtt
789     }     // namespace Crt
790 } // namespace Aws
791