1 /**
2  * Copyright (c) 2017, Andrew Gault, Nick Chadwick and Guillaume Egles.
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *    * Redistributions of source code must retain the above copyright
8  *      notice, this list of conditions and the following disclaimer.
9  *    * Redistributions in binary form must reproduce the above copyright
10  *      notice, this list of conditions and the following disclaimer in the
11  *      documentation and/or other materials provided with the distribution.
12  *    * Neither the name of the <organization> nor the
13  *      names of its contributors may be used to endorse or promote products
14  *      derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE
20  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 
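/**
 * Wrapper around usrsctp: feeds SCTP packets received from the DTLS layer into usrsctp and
 * hands packets produced by usrsctp back to the DTLS layer for encryption.
 */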
31 
#include "rtcdcpp/SCTPWrapper.hpp"

#include <cerrno>
#include <chrono>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <thread>
35 
36 namespace rtcdcpp {
37 
38 using namespace std;
39 
SCTPWrapper(DTLSEncryptCallbackPtr dtlsEncryptCB,MsgReceivedCallbackPtr msgReceivedCB)40 SCTPWrapper::SCTPWrapper(DTLSEncryptCallbackPtr dtlsEncryptCB, MsgReceivedCallbackPtr msgReceivedCB)
41     : local_port(5000),  // XXX: Hard-coded for now
42       remote_port(5000),
43       stream_cursor(0),
44       dtlsEncryptCallback(dtlsEncryptCB),
45       msgReceivedCallback(msgReceivedCB) {}
46 
~SCTPWrapper()47 SCTPWrapper::~SCTPWrapper() {
48   Stop();
49 
50   int tries = 0;
51   while (usrsctp_finish() != 0 && tries < 5) {
52     std::this_thread::sleep_for(std::chrono::milliseconds(1000));
53   }
54 }
55 
56 static uint16_t interested_events[] = {SCTP_ASSOC_CHANGE,         SCTP_PEER_ADDR_CHANGE,   SCTP_REMOTE_ERROR,          SCTP_SEND_FAILED,
57                                        SCTP_SENDER_DRY_EVENT,     SCTP_SHUTDOWN_EVENT,     SCTP_ADAPTATION_INDICATION, SCTP_PARTIAL_DELIVERY_EVENT,
58                                        SCTP_AUTHENTICATION_EVENT, SCTP_STREAM_RESET_EVENT, SCTP_ASSOC_RESET_EVENT,     SCTP_STREAM_CHANGE_EVENT,
59                                        SCTP_SEND_FAILED_EVENT};
60 
61 // TODO: error callbacks
OnNotification(union sctp_notification * notify,size_t len)62 void SCTPWrapper::OnNotification(union sctp_notification *notify, size_t len) {
63   if (notify->sn_header.sn_length != (uint32_t)len) {
64     logger->error("OnNotification(len={}) invalid length: {}", len, notify->sn_header.sn_length);
65     return;
66   }
67 
68   switch (notify->sn_header.sn_type) {
69     case SCTP_ASSOC_CHANGE:
70       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_ASSOC_CHANGE)");
71       break;
72     case SCTP_PEER_ADDR_CHANGE:
73       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_PEER_ADDR_CHANGE)");
74       break;
75     case SCTP_REMOTE_ERROR:
76       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_REMOTE_ERROR)");
77       break;
78     case SCTP_SEND_FAILED_EVENT:
79       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_SEND_FAILED_EVENT)");
80       break;
81     case SCTP_SHUTDOWN_EVENT:
82       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_SHUTDOWN_EVENT)");
83       break;
84     case SCTP_ADAPTATION_INDICATION:
85       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_ADAPTATION_INDICATION)");
86       break;
87     case SCTP_PARTIAL_DELIVERY_EVENT:
88       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_PARTIAL_DELIVERY_EVENT)");
89       break;
90     case SCTP_AUTHENTICATION_EVENT:
91       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_AUTHENTICATION_EVENT)");
92       break;
93     case SCTP_SENDER_DRY_EVENT:
94       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_SENDER_DRY_EVENT)");
95       break;
96     case SCTP_NOTIFICATIONS_STOPPED_EVENT:
97       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_NOTIFICATIONS_STOPPED_EVENT)");
98       break;
99     case SCTP_STREAM_RESET_EVENT:
100       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_STREAM_RESET_EVENT)");
101       break;
102     case SCTP_ASSOC_RESET_EVENT:
103       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_ASSOC_RESET_EVENT)");
104       break;
105     case SCTP_STREAM_CHANGE_EVENT:
106       SPDLOG_TRACE(logger, "OnNotification(type=SCTP_STREAM_CHANGE_EVENT)");
107       break;
108     default:
109       SPDLOG_TRACE(logger, "OnNotification(type={} (unknown))", notify->sn_header.sn_type);
110       break;
111   }
112 }
113 
_OnSCTPForDTLS(void * sctp_ptr,void * data,size_t len,uint8_t tos,uint8_t set_df)114 int SCTPWrapper::_OnSCTPForDTLS(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) {
115   if (sctp_ptr) {
116     return static_cast<SCTPWrapper *>(sctp_ptr)->OnSCTPForDTLS(data, len, tos, set_df);
117   } else {
118     return -1;
119   }
120 }
121 
OnSCTPForDTLS(void * data,size_t len,uint8_t tos,uint8_t set_df)122 int SCTPWrapper::OnSCTPForDTLS(void *data, size_t len, uint8_t tos, uint8_t set_df) {
123   SPDLOG_TRACE(logger, "Data ready. len={}, tos={}, set_df={}", len, tos, set_df);
124   this->dtlsEncryptCallback(std::make_shared<Chunk>(data, len));
125 
126   {
127     unique_lock<mutex> l(connectMtx);
128     this->connectSentData = true;
129     connectCV.notify_one();
130   }
131 
132   return 0;  // success
133 }
134 
_DebugLog(const char * format,...)135 void SCTPWrapper::_DebugLog(const char *format, ...) {
136   va_list ap;
137   va_start(ap, format);
138   // std::string msg = Util::FormatString(format, ap);
139   char msg[1024 * 16];
140   vsprintf(msg, format, ap);
141   GetLogger("librtcpp.SCTP")->trace("SCTP: msg={}", msg);
142   va_end(ap);
143 }
144 
_OnSCTPForGS(struct socket * sock,union sctp_sockstore addr,void * data,size_t len,struct sctp_rcvinfo recv_info,int flags,void * user_data)145 int SCTPWrapper::_OnSCTPForGS(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, struct sctp_rcvinfo recv_info, int flags,
146                               void *user_data) {
147   if (user_data) {
148     return static_cast<SCTPWrapper *>(user_data)->OnSCTPForGS(sock, addr, data, len, recv_info, flags);
149   } else {
150     return -1;
151   }
152 }
153 
OnSCTPForGS(struct socket * sock,union sctp_sockstore addr,void * data,size_t len,struct sctp_rcvinfo recv_info,int flags)154 int SCTPWrapper::OnSCTPForGS(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, struct sctp_rcvinfo recv_info, int flags) {
155   if (len == 0) {
156     return -1;
157   }
158 
159   SPDLOG_TRACE(logger, "Data received. stream={}, len={}, SSN={}, TSN={}, PPID={}",
160                 len,
161                 recv_info.rcv_sid,
162                 recv_info.rcv_ssn,
163                 recv_info.rcv_tsn,
164                 ntohl(recv_info.rcv_ppid));
165 
166   if (flags & MSG_NOTIFICATION) {
167     OnNotification((union sctp_notification *)data, len);
168   } else {
169     std::cout << "Got msg of size: " << len << "\n";
170     OnMsgReceived((const uint8_t *)data, len, recv_info.rcv_sid, ntohl(recv_info.rcv_ppid));
171   }
172   free(data);
173   return 0;
174 }
175 
OnMsgReceived(const uint8_t * data,size_t len,int ppid,int sid)176 void SCTPWrapper::OnMsgReceived(const uint8_t *data, size_t len, int ppid, int sid) {
177   this->msgReceivedCallback(std::make_shared<Chunk>(data, len), ppid, sid);
178 }
179 
Initialize()180 bool SCTPWrapper::Initialize() {
181   usrsctp_init(0, &SCTPWrapper::_OnSCTPForDTLS, &SCTPWrapper::_DebugLog);
182   usrsctp_sysctl_set_sctp_ecn_enable(0);
183   usrsctp_register_address(this);
184 
185   sock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SCTPWrapper::_OnSCTPForGS, NULL, 0, this);
186   if (!sock) {
187     logger->error("Could not create usrsctp_socket. errno={}", errno);
188     return false;
189   }
190 
191   struct linger linger_opt;
192   linger_opt.l_onoff = 1;
193   linger_opt.l_linger = 0;
194   if (usrsctp_setsockopt(this->sock, SOL_SOCKET, SO_LINGER, &linger_opt, sizeof(linger_opt)) == -1) {
195     logger->error("Could not set socket options for SO_LINGER. errno={}", errno);
196     return false;
197   }
198 
199   struct sctp_paddrparams peer_param;
200   memset(&peer_param, 0, sizeof(peer_param));
201   peer_param.spp_flags = SPP_PMTUD_DISABLE;
202   peer_param.spp_pathmtu = 1200;  // XXX: Does this need to match the actual MTU?
203   if (usrsctp_setsockopt(this->sock, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peer_param, sizeof(peer_param)) == -1) {
204     logger->error("Could not set socket options for SCTP_PEER_ADDR_PARAMS. errno={}", errno);
205     return false;
206   }
207 
208   struct sctp_assoc_value av;
209   av.assoc_id = SCTP_ALL_ASSOC;
210   av.assoc_value = 1;
211   if (usrsctp_setsockopt(this->sock, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av)) == -1) {
212     logger->error("Could not set socket options for SCTP_ENABLE_STREAM_RESET. errno={}", errno);
213     return false;
214   }
215 
216   uint32_t nodelay = 1;
217   if (usrsctp_setsockopt(this->sock, IPPROTO_SCTP, SCTP_NODELAY, &nodelay, sizeof(nodelay)) == -1) {
218     logger->error("Could not set socket options for SCTP_NODELAY. errno={}", errno);
219     return false;
220   }
221 
222   /* Enable the events of interest */
223   struct sctp_event event;
224   memset(&event, 0, sizeof(event));
225   event.se_assoc_id = SCTP_ALL_ASSOC;
226   event.se_on = 1;
227   int num_events = sizeof(interested_events) / sizeof(uint16_t);
228   for (int i = 0; i < num_events; i++) {
229     event.se_type = interested_events[i];
230     if (usrsctp_setsockopt(this->sock, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event)) == -1) {
231       logger->error("Could not set socket options for SCTP_EVENT {}. errno={}", i, errno);
232       return false;
233     }
234   }
235 
236   struct sctp_initmsg init_msg;
237   memset(&init_msg, 0, sizeof(init_msg));
238   init_msg.sinit_num_ostreams = MAX_OUT_STREAM;
239   init_msg.sinit_max_instreams = MAX_IN_STREAM;
240   if (usrsctp_setsockopt(this->sock, IPPROTO_SCTP, SCTP_INITMSG, &init_msg, sizeof(init_msg)) == -1) {
241     logger->error("Could not set socket options for SCTP_INITMSG. errno={}", errno);
242     return false;
243   }
244 
245   struct sockaddr_conn sconn;
246   sconn.sconn_family = AF_CONN;
247   sconn.sconn_port = htons(remote_port);
248   sconn.sconn_addr = (void *)this;
249 #ifdef HAVE_SCONN_LEN
250   sconn.sconn_len = sizeof(struct sockaddr_conn);
251 #endif
252 
253   if (usrsctp_bind(this->sock, (struct sockaddr *)&sconn, sizeof(sconn)) == -1) {
254     logger->error("Could not usrsctp_bind. errno={}", errno);
255     return false;
256   }
257 
258   return true;
259 }
260 
Start()261 void SCTPWrapper::Start() {
262   if (started) {
263     logger->error("Start() - already started!");
264     return;
265   }
266 
267   SPDLOG_TRACE(logger, "Start()");
268   started = true;
269 
270   this->recv_thread = std::thread(&SCTPWrapper::RecvLoop, this);
271   this->connect_thread = std::thread(&SCTPWrapper::RunConnect, this);
272 }
273 
Stop()274 void SCTPWrapper::Stop() {
275   this->should_stop = true;
276 
277   send_queue.Stop();
278   recv_queue.Stop();
279 
280   connectCV.notify_one();  // unblock the recv thread in case we never connected
281   if (this->recv_thread.joinable()) {
282     this->recv_thread.join();
283   }
284 
285   if (this->connect_thread.joinable()) {
286     this->connect_thread.join();
287   }
288 
289   if (sock) {
290     usrsctp_shutdown(sock, SHUT_RDWR);
291     usrsctp_close(sock);
292     sock = nullptr;
293   }
294 }
295 
DTLSForSCTP(ChunkPtr chunk)296 void SCTPWrapper::DTLSForSCTP(ChunkPtr chunk) { this->recv_queue.push(chunk); }
297 
298 // Send a message to the remote connection
GSForSCTP(ChunkPtr chunk,uint16_t sid,uint32_t ppid)299 void SCTPWrapper::GSForSCTP(ChunkPtr chunk, uint16_t sid, uint32_t ppid) {
300   struct sctp_sendv_spa spa = {0};
301 
302   // spa.sendv_flags = SCTP_SEND_SNDINFO_VALID | SCTP_SEND_PRINFO_VALID;
303   spa.sendv_flags = SCTP_SEND_SNDINFO_VALID;
304 
305   spa.sendv_sndinfo.snd_sid = sid;
306   // spa.sendv_sndinfo.snd_flags = SCTP_EOR | SCTP_UNORDERED;
307   spa.sendv_sndinfo.snd_flags = SCTP_EOR;
308   spa.sendv_sndinfo.snd_ppid = htonl(ppid);
309 
310   // spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX;
311   // spa.sendv_prinfo.pr_value = 0;
312 
313   int tries = 0;
314   while (tries < 5) {
315     if (usrsctp_sendv(this->sock, chunk->Data(), chunk->Length(), NULL, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0) < 0) {
316       logger->error("FAILED to send, try: {}", tries);
317       tries += 1;
318       std::this_thread::sleep_for(std::chrono::seconds(tries));
319     } else {
320       return;
321     }
322   }
323   //tried about 5 times and still no luck
324   throw std::runtime_error("Send failed");
325 }
326 
RecvLoop()327 void SCTPWrapper::RecvLoop() {
328   // Util::SetThreadName("SCTP-RecvLoop");
329 //  NDC ndc("SCTP-RecvLoop");
330 
331   SPDLOG_TRACE(logger, "RunRecv()");
332 
333   {
334     // We need to wait for the connect thread to send some data
335     unique_lock<mutex> l(connectMtx);
336     while (!this->connectSentData && !this->should_stop) {
337       connectCV.wait_for(l, chrono::milliseconds(100));
338     }
339   }
340 
341   SPDLOG_DEBUG(logger, "RunRecv() sent_data=true");
342 
343   while (!this->should_stop) {
344     ChunkPtr chunk = this->recv_queue.wait_and_pop();
345     if (!chunk) {
346       return;
347     }
348     SPDLOG_DEBUG(logger, "RunRecv() Handling packet of len - {}", chunk->Length());
349     usrsctp_conninput(this, chunk->Data(), chunk->Length(), 0);
350   }
351 }
352 
RunConnect()353 void SCTPWrapper::RunConnect() {
354   // Util::SetThreadName("SCTP-Connect");
355   SPDLOG_TRACE(logger, "RunConnect() port={}", remote_port);
356 
357   struct sockaddr_conn sconn;
358   sconn.sconn_family = AF_CONN;
359   sconn.sconn_port = htons(remote_port);
360   sconn.sconn_addr = (void *)this;
361 #ifdef HAVE_SCONN_LEN
362   sconn.sconn_len = sizeof((void *)this);
363 #endif
364 
365   // Blocks until connection succeeds/fails
366   int connect_result = usrsctp_connect(sock, (struct sockaddr *)&sconn, sizeof sconn);
367 
368   if ((connect_result < 0) && (errno != EINPROGRESS)) {
369     SPDLOG_DEBUG(logger, "Connection failed. errno={}", errno);
370     should_stop = true;
371 
372     {
373       // Unblock the recv thread
374       unique_lock<mutex> l(connectMtx);
375       connectCV.notify_one();
376     }
377 
378     // TODO let the world know we failed :(
379 
380   } else {
381     SPDLOG_DEBUG(logger, "Connected on port {}", remote_port);
382   }
383 }
384 }
385