1 /*
2  * Authored by Alex Hultman, 2018-2021.
3  * Intellectual property of third-party.
4 
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8 
9  *     http://www.apache.org/licenses/LICENSE-2.0
10 
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #ifndef UWS_WEBSOCKET_H
19 #define UWS_WEBSOCKET_H
20 
21 #include "WebSocketData.h"
22 #include "WebSocketProtocol.h"
23 #include "AsyncSocket.h"
24 #include "WebSocketContextData.h"
25 
26 #include <string_view>
27 
28 namespace uWS {
29 
30 template <bool SSL, bool isServer, typename USERDATA>
31 struct WebSocket : AsyncSocket<SSL> {
32     template <bool> friend struct TemplatedApp;
33     template <bool> friend struct HttpResponse;
34 private:
35     typedef AsyncSocket<SSL> Super;
36 
initWebSocket37     void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) {
38         new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure));
39         return this;
40     }
41 public:
42 
43     /* Returns pointer to the per socket user data */
getUserDataWebSocket44     USERDATA *getUserData() {
45         WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
46         /* We just have it overallocated by sizeof type */
47         return (USERDATA *) (webSocketData + 1);
48     }
49 
50     /* See AsyncSocket */
51     using Super::getBufferedAmount;
52     using Super::getRemoteAddress;
53     using Super::getRemoteAddressAsText;
54     using Super::getNativeHandle;
55 
56     /* WebSocket close cannot be an alias to AsyncSocket::close since
57      * we need to check first if it was shut down by remote peer */
closeWebSocket58     us_socket_t *close() {
59         if (us_socket_is_closed(SSL, (us_socket_t *) this)) {
60             return nullptr;
61         }
62         WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
63         if (webSocketData->isShuttingDown) {
64             return nullptr;
65         }
66 
67         return us_socket_close(SSL, (us_socket_t *) this, 0, nullptr);
68     }
69 
70     enum SendStatus : int {
71         BACKPRESSURE,
72         SUCCESS,
73         DROPPED
74     };
75 
76     /* Sending fragmented messages puts a bit of effort on the user; you must not interleave regular sends
77      * with fragmented sends and you must sendFirstFragment, [sendFragment], then finally sendLastFragment. */
78     SendStatus sendFirstFragment(std::string_view message, OpCode opCode = OpCode::BINARY, bool compress = false) {
79         return send(message, opCode, compress, false);
80     }
81 
82     SendStatus sendFragment(std::string_view message, bool compress = false) {
83         return send(message, CONTINUATION, compress, false);
84     }
85 
86     SendStatus sendLastFragment(std::string_view message, bool compress = false) {
87         return send(message, CONTINUATION, compress, true);
88     }
89 
90     /* Send or buffer a WebSocket frame, compressed or not. Returns BACKPRESSURE on increased user space backpressure,
91      * DROPPED on dropped message (due to backpressure) or SUCCCESS if you are free to send even more now. */
92     SendStatus send(std::string_view message, OpCode opCode = OpCode::BINARY, bool compress = false, bool fin = true) {
93         WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
94             (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
95         );
96 
97         /* Skip sending and report success if we are over the limit of maxBackpressure */
98         if (webSocketContextData->maxBackpressure && webSocketContextData->maxBackpressure < getBufferedAmount()) {
99             /* Also defer a close if we should */
100             if (webSocketContextData->closeOnBackpressureLimit) {
101                 us_socket_shutdown_read(SSL, (us_socket_t *) this);
102             }
103             return DROPPED;
104         }
105 
106         /* If we are subscribers and have messages to drain we need to drain them here to stay synced */
107         WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
108         if (webSocketData->subscriber) {
109             /* This will call back into us, send. */
110             webSocketContextData->topicTree->drain(webSocketData->subscriber);
111         }
112 
113         /* Transform the message to compressed domain if requested */
114         if (compress) {
115             WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
116 
117             /* Check and correct the compress hint. It is never valid to compress 0 bytes */
118             if (message.length() && opCode < 3 && webSocketData->compressionStatus == WebSocketData::ENABLED) {
119                 LoopData *loopData = Super::getLoopData();
120                 /* Compress using either shared or dedicated deflationStream */
121                 if (webSocketData->deflationStream) {
122                     message = webSocketData->deflationStream->deflate(loopData->zlibContext, message, false);
123                 } else {
124                     message = loopData->deflationStream->deflate(loopData->zlibContext, message, true);
125                 }
126             } else {
127                 compress = false;
128             }
129         }
130 
131         /* Get size, allocate size, write if needed */
132         size_t messageFrameSize = protocol::messageFrameSize(message.length());
133         auto [sendBuffer, sendBufferAttribute] = Super::getSendBuffer(messageFrameSize);
134         protocol::formatMessage<isServer>(sendBuffer, message.data(), message.length(), opCode, message.length(), compress, fin);
135 
136         /* Depending on size of message we have different paths */
137         if (sendBufferAttribute == SendBufferAttribute::NEEDS_DRAIN) {
138             /* This is a drain */
139             auto[written, failed] = Super::write(nullptr, 0);
140             if (failed) {
141                 /* Return false for failure, skipping to reset the timeout below */
142                 return BACKPRESSURE;
143             }
144         } else if (sendBufferAttribute == SendBufferAttribute::NEEDS_UNCORK) {
145             /* Uncork if we came here uncorked */
146             auto [written, failed] = Super::uncork();
147             if (failed) {
148                 return BACKPRESSURE;
149             }
150         }
151 
152         /* Every successful send resets the timeout */
153         if (webSocketContextData->resetIdleTimeoutOnSend) {
154             Super::timeout(webSocketContextData->idleTimeoutComponents.first);
155             WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
156             webSocketData->hasTimedOut = false;
157         }
158 
159         /* Return success */
160         return SUCCESS;
161     }
162 
163     /* Send websocket close frame, emit close event, send FIN if successful.
164      * Will not append a close reason if code is 0 or 1005. */
165     void end(int code = 0, std::string_view message = {}) {
166         /* Check if we already called this one */
167         WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
168         if (webSocketData->isShuttingDown) {
169             return;
170         }
171 
172         /* We postpone any FIN sending to either drainage or uncorking */
173         webSocketData->isShuttingDown = true;
174 
175         /* Format and send the close frame */
176         static const int MAX_CLOSE_PAYLOAD = 123;
177         size_t length = std::min<size_t>(MAX_CLOSE_PAYLOAD, message.length());
178         char closePayload[MAX_CLOSE_PAYLOAD + 2];
179         size_t closePayloadLength = protocol::formatClosePayload(closePayload, (uint16_t) code, message.data(), length);
180         bool ok = send(std::string_view(closePayload, closePayloadLength), OpCode::CLOSE);
181 
182         /* FIN if we are ok and not corked */
183         if (!this->isCorked()) {
184             if (ok) {
185                 /* If we are not corked, and we just sent off everything, we need to FIN right here.
186                  * In all other cases, we need to fin either if uncork was successful, or when drainage is complete. */
187                 this->shutdown();
188             }
189         }
190 
191         WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
192             (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
193         );
194 
195         /* Set shorter timeout (use ping-timeout) to avoid long hanging sockets after end() on broken connections */
196         Super::timeout(webSocketContextData->idleTimeoutComponents.second);
197 
198         /* Emit close event */
199         if (webSocketContextData->closeHandler) {
200             webSocketContextData->closeHandler(this, code, message);
201         }
202 
203         /* Make sure to unsubscribe from any pub/sub node at exit */
204         webSocketContextData->topicTree->freeSubscriber(webSocketData->subscriber);
205         webSocketData->subscriber = nullptr;
206     }
207 
208     /* Corks the response if possible. Leaves already corked socket be. */
corkWebSocket209     void cork(MoveOnlyFunction<void()> &&handler) {
210         if (!Super::isCorked() && Super::canCork()) {
211             Super::cork();
212             handler();
213 
214             /* There is no timeout when failing to uncork for WebSockets,
215              * as that is handled by idleTimeout */
216             auto [written, failed] = Super::uncork();
217         } else {
218             /* We are already corked, or can't cork so let's just call the handler */
219             handler();
220         }
221     }
222 
223     /* Subscribe to a topic according to MQTT rules and syntax. Returns success */
224     bool subscribe(std::string_view topic, bool = false) {
225         WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
226             (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
227         );
228 
229         /* Make us a subscriber if we aren't yet */
230         WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
231         if (!webSocketData->subscriber) {
232             webSocketData->subscriber = webSocketContextData->topicTree->createSubscriber();
233             webSocketData->subscriber->user = this;
234         }
235 
236         /* Cannot return numSubscribers as this is only for this particular websocket context */
237         webSocketContextData->topicTree->subscribe(webSocketData->subscriber, topic);
238 
239         /* Subscribe always succeeds */
240         return true;
241     }
242 
243     /* Unsubscribe from a topic, returns true if we were subscribed. */
244     bool unsubscribe(std::string_view topic, bool = false) {
245         WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
246             (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
247         );
248 
249         WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
250 
251         /* Cannot return numSubscribers as this is only for this particular websocket context */
252         auto [ok, last] = webSocketContextData->topicTree->unsubscribe(webSocketData->subscriber, topic);
253 
254         /* Free us as subscribers if we unsubscribed from our last topic */
255         if (ok && last) {
256             webSocketContextData->topicTree->freeSubscriber(webSocketData->subscriber);
257             webSocketData->subscriber = nullptr;
258         }
259 
260         return ok;
261     }
262 
263     /* Returns whether this socket is subscribed to the specified topic */
isSubscribedWebSocket264     bool isSubscribed(std::string_view topic) {
265         WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
266             (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
267         );
268 
269         WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
270         if (!webSocketData->subscriber) {
271             return false;
272         }
273 
274         Topic *topicPtr = webSocketContextData->topicTree->lookupTopic(topic);
275         if (!topicPtr) {
276             return false;
277         }
278 
279         return topicPtr->count(webSocketData->subscriber);
280     }
281 
282     /* Iterates all topics of this WebSocket. Every topic is represented by its full name.
283      * Can be called in close handler. It is possible to modify the subscription list while
284      * inside the callback ONLY IF not modifying the topic passed to the callback.
285      * Topic names are valid only for the duration of the callback. */
iterateTopicsWebSocket286     void iterateTopics(MoveOnlyFunction<void(std::string_view)> cb) {
287         WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
288             (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
289         );
290 
291         WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
292         if (webSocketData->subscriber) {
293             /* Lock this subscriber for unsubscription / subscription */
294             webSocketContextData->topicTree->iteratingSubscriber = webSocketData->subscriber;
295 
296             for (Topic *topicPtr : webSocketData->subscriber->topics) {
297                 cb({topicPtr->name.data(), topicPtr->name.length()});
298             }
299 
300             /* Unlock subscriber */
301             webSocketContextData->topicTree->iteratingSubscriber = nullptr;
302         }
303     }
304 
305     /* Publish a message to a topic according to MQTT rules and syntax. Returns success.
306      * We, the WebSocket, must be subscribed to the topic itself and if so - no message will be sent to ourselves.
307      * Use App::publish for an unconditional publish that simply publishes to whomever might be subscribed. */
308     bool publish(std::string_view topic, std::string_view message, OpCode opCode = OpCode::TEXT, bool compress = false) {
309         WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
310             (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
311         );
312 
313         /* We cannot be a subscriber of this topic if we are not a subscriber of anything */
314         WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
315         if (!webSocketData->subscriber) {
316             /* Failure, but still do return the number of subscribers */
317             return false;
318         }
319 
320         /* Publish as sender, does not receive its own messages even if subscribed to relevant topics */
321         if (message.length() >= LoopData::CORK_BUFFER_SIZE) {
322             return webSocketContextData->topicTree->publishBig(webSocketData->subscriber, topic, {message, opCode, compress}, [](Subscriber *s, TopicTreeBigMessage &message) {
323                 auto *ws = (WebSocket<SSL, true, int> *) s->user;
324 
325                 ws->send(message.message, (OpCode)message.opCode, message.compress);
326             });
327         } else {
328             return webSocketContextData->topicTree->publish(webSocketData->subscriber, topic, {std::string(message), opCode, compress});
329         }
330     }
331 };
332 
333 }
334 
335 #endif // UWS_WEBSOCKET_H
336