1 /*
2  * Copyright (C) 2001-2012 Jacek Sieka, arnetheduck on gmail point com
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation; either version 2 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17  */
18 
19 #include "stdinc.h"
20 
21 #include "BufferedSocket.h"
22 
23 #include "TimerManager.h"
24 #include "SettingsManager.h"
25 
26 #include "Streams.h"
27 #include "SSLSocket.h"
28 #include "CryptoManager.h"
29 #include "ZUtils.h"
30 
31 #include "ThrottleManager.h"
32 
33 namespace dcpp {
34 
35 // Polling is used for tasks...should be fixed...
36 #define POLL_TIMEOUT 250
37 
BufferedSocket(char aSeparator)38 BufferedSocket::BufferedSocket(char aSeparator) :
39 separator(aSeparator), mode(MODE_LINE), dataBytes(0), rollback(0), state(STARTING),
40 disconnecting(false)
41 {
42     start();
43 
44     sockets.inc();
45 }
46 
47 Atomic<long,memory_ordering_strong> BufferedSocket::sockets(0);
48 
~BufferedSocket()49 BufferedSocket::~BufferedSocket() {
50     sockets.dec();
51 }
52 
setMode(Modes aMode,size_t aRollback)53 void BufferedSocket::setMode (Modes aMode, size_t aRollback) {
54     if (mode == aMode) {
55         dcdebug ("WARNING: Re-entering mode %d\n", mode);
56         return;
57     }
58 
59     switch (aMode) {
60         case MODE_LINE:
61             rollback = aRollback;
62             break;
63         case MODE_ZPIPE:
64             filterIn = std::unique_ptr<UnZFilter>(new UnZFilter);
65             break;
66         case MODE_DATA:
67             break;
68     }
69     mode = aMode;
70 }
71 
setSocket(std::unique_ptr<Socket> s)72 void BufferedSocket::setSocket(std::unique_ptr<Socket> s) {
73     dcassert(!sock.get());
74     if(SETTING(SOCKET_IN_BUFFER) > 0)
75         s->setSocketOpt(SO_RCVBUF, SETTING(SOCKET_IN_BUFFER));
76     if(SETTING(SOCKET_OUT_BUFFER) > 0)
77         s->setSocketOpt(SO_SNDBUF, SETTING(SOCKET_OUT_BUFFER));
78     s->setSocketOpt(SO_REUSEADDR, 1);   // NAT traversal
79 
80     inbuf.resize(s->getSocketOptInt(SO_RCVBUF));
81 
82     sock = move(s);
83 }
84 
accept(const Socket & srv,bool secure,bool allowUntrusted)85 void BufferedSocket::accept(const Socket& srv, bool secure, bool allowUntrusted) {
86     dcdebug("BufferedSocket::accept() %p\n", (void*)this);
87 
88     std::unique_ptr<Socket> s(secure ? CryptoManager::getInstance()->getServerSocket(allowUntrusted) : new Socket);
89 
90     s->accept(srv);
91 
92     setSocket(move(s));
93 
94     Lock l(cs);
95     addTask(ACCEPTED, 0);
96 }
97 
connect(const string & aAddress,uint16_t aPort,bool secure,bool allowUntrusted,bool proxy)98 void BufferedSocket::connect(const string& aAddress, uint16_t aPort, bool secure, bool allowUntrusted, bool proxy) {
99     connect(aAddress, aPort, 0, NAT_NONE, secure, allowUntrusted, proxy);
100 }
101 
connect(const string & aAddress,uint16_t aPort,uint16_t localPort,NatRoles natRole,bool secure,bool allowUntrusted,bool proxy)102 void BufferedSocket::connect(const string& aAddress, uint16_t aPort, uint16_t localPort, NatRoles natRole, bool secure, bool allowUntrusted, bool proxy) {
103     dcdebug("BufferedSocket::connect() %p\n", (void*)this);
104     std::unique_ptr<Socket> s(secure ? (natRole == NAT_SERVER ? CryptoManager::getInstance()->getServerSocket(allowUntrusted) : CryptoManager::getInstance()->getClientSocket(allowUntrusted)) : new Socket);
105 
106     s->create();
107     setSocket(move(s));
108     sock->bind(localPort, SETTING(BIND_IFACE)? sock->getIfaceI4(SETTING(BIND_IFACE_NAME)).c_str() : SETTING(BIND_ADDRESS));
109 
110     Lock l(cs);
111     addTask(CONNECT, new ConnectInfo(aAddress, aPort, localPort, natRole, proxy && (SETTING(OUTGOING_CONNECTIONS) == SettingsManager::OUTGOING_SOCKS5)));
112 }
113 
114 #define LONG_TIMEOUT 30000
115 #define SHORT_TIMEOUT 1000
threadConnect(const string & aAddr,uint16_t aPort,uint16_t localPort,NatRoles natRole,bool proxy)116 void BufferedSocket::threadConnect(const string& aAddr, uint16_t aPort, uint16_t localPort, NatRoles natRole, bool proxy) {
117     dcassert(state == STARTING);
118 
119     dcdebug("threadConnect %s:%d/%d\n", aAddr.c_str(), (int)localPort, (int)aPort);
120     fire(BufferedSocketListener::Connecting());
121 
122     const uint64_t endTime = GET_TICK() + LONG_TIMEOUT;
123     state = RUNNING;
124 
125     while (GET_TICK() < endTime) {
126         dcdebug("threadConnect attempt to addr \"%s\"\n", aAddr.c_str());
127         try {
128             if(proxy) {
129                 sock->socksConnect(aAddr, aPort, LONG_TIMEOUT);
130             } else {
131                 sock->connect(aAddr, aPort);
132             }
133 
134             bool connSucceeded;
135             while(!(connSucceeded = sock->waitConnected(POLL_TIMEOUT)) && endTime >= GET_TICK()) {
136                 if(disconnecting) return;
137             }
138 
139             if (connSucceeded) {
140                 fire(BufferedSocketListener::Connected());
141                 return;
142             }
143         }
144         catch (const SSLSocketException&) {
145             throw;
146         } catch (const SocketException&) {
147             if (natRole == NAT_NONE)
148                 throw;
149             Thread::sleep(SHORT_TIMEOUT);
150         }
151     }
152 
153     throw SocketException(_("Connection timeout"));
154 }
155 
threadAccept()156 void BufferedSocket::threadAccept() {
157     dcassert(state == STARTING);
158 
159     dcdebug("threadAccept\n");
160 
161     state = RUNNING;
162 
163     uint64_t startTime = GET_TICK();
164     while(!sock->waitAccepted(POLL_TIMEOUT)) {
165         if(disconnecting)
166             return;
167 
168         if((startTime + 30000) < GET_TICK()) {
169             throw SocketException(_("Connection timeout"));
170         }
171     }
172 }
173 
threadRead()174 void BufferedSocket::threadRead() {
175     if(state != RUNNING)
176         return;
177 
178     int left = (mode == MODE_DATA) ? ThrottleManager::getInstance()->read(sock.get(), &inbuf[0], (int)inbuf.size()) : sock->read(&inbuf[0], (int)inbuf.size());
179     if(left == -1) {
180         // EWOULDBLOCK, no data received...
181         return;
182     } else if(left == 0) {
183         // This socket has been closed...
184         throw SocketException(_("Connection closed"));
185     }
186     string::size_type pos = 0;
187     // always uncompressed data
188     string l;
189     int bufpos = 0, total = left;
190 
191     while (left > 0) {
192         switch (mode) {
193             case MODE_ZPIPE: {
194                     const int BUF_SIZE = 1024;
195                     // Special to autodetect nmdc connections...
196                     string::size_type pos = 0;
197                     boost::scoped_array<char> buffer(new char[BUF_SIZE]);
198                     l = line;
199                     // decompress all input data and store in l.
200                     while (left) {
201                         size_t in = BUF_SIZE;
202                         size_t used = left;
203                         bool ret = (*filterIn) (&inbuf[0] + total - left, used, &buffer[0], in);
204                         left -= used;
205                         l.append (&buffer[0], in);
206                         // if the stream ends before the data runs out, keep remainder of data in inbuf
207                         if (!ret) {
208                             bufpos = total-left;
209                             setMode (MODE_LINE, rollback);
210                             break;
211                         }
212                     }
213                     // process all lines
214                     while ((pos = l.find(separator)) != string::npos) {
215                         if(pos > 0) // check empty (only pipe) command and don't waste cpu with it ;o)
216                             fire(BufferedSocketListener::Line(), l.substr(0, pos));
217                         l.erase (0, pos + 1 /* separator char */);
218                     }
219                     // store remainder
220                     line = l;
221 
222                     break;
223                 }
224             case MODE_LINE:
225                 // Special to autodetect nmdc connections...
226                 if(separator == 0) {
227                     if(inbuf[0] == '$') {
228                         separator = '|';
229                     } else {
230                         separator = '\n';
231                     }
232                 }
233                 l = line + string ((char*)&inbuf[bufpos], left);
234                 while ((pos = l.find(separator)) != string::npos) {
235                     if(pos > 0) // check empty (only pipe) command and don't waste cpu with it ;o)
236                         fire(BufferedSocketListener::Line(), l.substr(0, pos));
237                     l.erase (0, pos + 1 /* separator char */);
238                     if (l.length() < (size_t)left) left = l.length();
239                     if (mode != MODE_LINE) {
240                         // we changed mode; remainder of l is invalid.
241                         l.clear();
242                         bufpos = total - left;
243                         break;
244                     }
245                 }
246                 if (pos == string::npos)
247                     left = 0;
248                 line = l;
249                 break;
250             case MODE_DATA:
251                 while(left > 0) {
252                     if(dataBytes == -1) {
253                         fire(BufferedSocketListener::Data(), &inbuf[bufpos], left);
254                         bufpos += (left - rollback);
255                         left = rollback;
256                         rollback = 0;
257                     } else {
258                         int high = (int)min(dataBytes, (int64_t)left);
259                         fire(BufferedSocketListener::Data(), &inbuf[bufpos], high);
260                         bufpos += high;
261                         left -= high;
262 
263                         dataBytes -= high;
264                         if(dataBytes == 0) {
265                             mode = MODE_LINE;
266                             fire(BufferedSocketListener::ModeChange());
267                         }
268                     }
269                 }
270                 break;
271         }
272     }
273 
274     if(mode == MODE_LINE && line.size() > static_cast<size_t>(SETTING(MAX_COMMAND_LENGTH))) {
275         throw SocketException(_("Maximum command length exceeded"));
276     }
277 }
278 
/**
 * Worker-thread side of a file transfer: streams `file` to the socket until
 * the stream is exhausted, then fires TransmitDone.
 *
 * Two buffers of max(SO_SNDBUF, 64 KiB) bytes are used. Each round fills
 * readBuf from the file, swaps it with writeBuf, and writes writeBuf out in
 * chunks of at most SO_SNDBUF / 2 (first attempt through ThrottleManager).
 * While a write would block, the loop reads ahead from the file if there is
 * room, or otherwise waits for the socket, servicing incoming data via
 * threadRead() in the meantime.
 *
 * BytesSent is fired as (bytesRead, 0) after file reads and (0, written)
 * after socket writes. Returns early (without TransmitDone) when not RUNNING
 * or once `disconnecting` is set.
 *
 * NOTE(review): bytesRead is passed to file->read() as an lvalue and reported
 * instead of `actual`; presumably InputStream::read takes it by reference and
 * updates it to the number of source bytes consumed - confirm in Streams.h.
 */
void BufferedSocket::threadSendFile(InputStream* file) {
    if(state != RUNNING)
        return;

    if(disconnecting)
        return;
    dcassert(file != NULL);
    size_t sockSize = (size_t)sock->getSocketOptInt(SO_SNDBUF);
    size_t bufSize = max(sockSize, (size_t)64*1024);

    ByteVector readBuf(bufSize);
    ByteVector writeBuf(bufSize);

    // Number of valid bytes currently held in readBuf.
    size_t readPos = 0;

    bool readDone = false;
    dcdebug("Starting threadSend\n");
    while(!disconnecting) {
        if(!readDone && readBuf.size() > readPos) {
            // Fill read buffer
            size_t bytesRead = readBuf.size() - readPos;
            size_t actual = file->read(&readBuf[readPos], bytesRead);

            if(bytesRead > 0) {
                fire(BufferedSocketListener::BytesSent(), bytesRead, 0);
            }

            // A read producing no output marks the end of the file.
            if(actual == 0) {
                readDone = true;
            } else {
                readPos += actual;
            }
        }

        // Nothing left to read and nothing buffered: the transfer is complete.
        if(readDone && readPos == 0) {
            fire(BufferedSocketListener::TransmitDone());
            return;
        }

        // The filled buffer becomes the write buffer (trimmed to its valid bytes);
        // the other one is reset to full size and refilled from the start.
        readBuf.swap(writeBuf);
        readBuf.resize(bufSize);
        writeBuf.resize(readPos);
        readPos = 0;

        size_t writePos = 0, writeSize = 0;
        // -1 means the previous write attempt would have blocked.
        int written = 0;

        while(writePos < writeBuf.size()) {
            if(disconnecting)
                return;

            if(written == -1) {
                // workaround for OpenSSL (crashes when previous write failed and now retrying with different writeSize)
                // Retry with exactly the same pointer/size, bypassing the throttle.
                // NOTE(review): exceptions are swallowed here, leaving written == -1 so the
                // loop keeps retrying; confirm a dead connection is still detected via threadRead().
                try {
                    written = sock->write(&writeBuf[writePos], writeSize);
                } catch(const Exception&) {
                    // ...
                }
            } else {
                writeSize = min(sockSize / 2, writeBuf.size() - writePos);
                written = ThrottleManager::getInstance()->write(sock.get(), &writeBuf[writePos], writeSize);
            }

            if(written > 0) {
                writePos += written;

                fire(BufferedSocketListener::BytesSent(), 0, written);

            } else if(written == -1) {
                if(!readDone && readPos < readBuf.size()) {
                    // Read a little since we're blocking anyway...
                    size_t bytesRead = min(readBuf.size() - readPos, readBuf.size() / 2);
                    size_t actual = file->read(&readBuf[readPos], bytesRead);

                    if(bytesRead > 0) {
                        fire(BufferedSocketListener::BytesSent(), bytesRead, 0);
                    }

                    if(actual == 0) {
                        readDone = true;
                    } else {
                        readPos += actual;
                    }
                } else {
                    // Read-ahead buffer is full (or the file is done): wait until the
                    // socket is writable, handling incoming data meanwhile.
                    while(!disconnecting) {
                        int w = sock->wait(POLL_TIMEOUT, Socket::WAIT_WRITE | Socket::WAIT_READ);
                        if(w & Socket::WAIT_READ) {
                            threadRead();
                        }
                        if(w & Socket::WAIT_WRITE) {
                            break;
                        }
                    }
                }
            }
        }
    }
}
377 
write(const char * aBuf,size_t aLen)378 void BufferedSocket::write(const char* aBuf, size_t aLen) noexcept {
379     if(!sock.get())
380         return;
381     Lock l(cs);
382     if(writeBuf.empty())
383         addTask(SEND_DATA, 0);
384 
385     writeBuf.insert(writeBuf.end(), aBuf, aBuf+aLen);
386 }
387 
threadSendData()388 void BufferedSocket::threadSendData() {
389     if(state != RUNNING)
390         return;
391 
392     {
393         Lock l(cs);
394         if(writeBuf.empty())
395             return;
396 
397         writeBuf.swap(sendBuf);
398     }
399 
400     size_t left = sendBuf.size();
401     size_t done = 0;
402     while(left > 0) {
403         if(disconnecting) {
404             return;
405         }
406 
407         int w = sock->wait(POLL_TIMEOUT, Socket::WAIT_READ | Socket::WAIT_WRITE);
408 
409         if(w & Socket::WAIT_READ) {
410             threadRead();
411         }
412 
413         if(w & Socket::WAIT_WRITE) {
414             int n = sock->write(&sendBuf[done], left);
415             if(n > 0) {
416                 left -= n;
417                 done += n;
418             }
419         }
420     }
421     sendBuf.clear();
422 }
423 
checkEvents()424 bool BufferedSocket::checkEvents() {
425     while(state == RUNNING ? taskSem.wait(0) : taskSem.wait()) {
426         pair<Tasks, unique_ptr<TaskData> > p;
427         {
428             Lock l(cs);
429             dcassert(!tasks.empty());
430             p = move(tasks.front());
431             tasks.erase(tasks.begin());
432         }
433 
434         if(p.first == SHUTDOWN) {
435             return false;
436         } else if(p.first == UPDATED) {
437             fire(BufferedSocketListener::Updated());
438             continue;
439         }
440 
441         if(state == STARTING) {
442             if(p.first == CONNECT) {
443                 ConnectInfo* ci = static_cast<ConnectInfo*>(p.second.get());
444                 threadConnect(ci->addr, ci->port, ci->localPort, ci->natRole, ci->proxy);
445             } else if(p.first == ACCEPTED) {
446                 threadAccept();
447             } else {
448                 dcdebug("%d unexpected in STARTING state\n", p.first);
449             }
450         } else if(state == RUNNING) {
451             if(p.first == SEND_DATA) {
452                 threadSendData();
453             } else if(p.first == SEND_FILE) {
454                 threadSendFile(static_cast<SendFileInfo*>(p.second.get())->stream); break;
455             } else if(p.first == DISCONNECT) {
456                 fail(_("Disconnected"));
457             } else {
458                 dcdebug("%d unexpected in RUNNING state\n", p.first);
459             }
460         }
461     }
462     return true;
463 }
464 
checkSocket()465 void BufferedSocket::checkSocket() {
466     int waitFor = sock->wait(POLL_TIMEOUT, Socket::WAIT_READ);
467 
468     if(waitFor & Socket::WAIT_READ) {
469         threadRead();
470     }
471 }
472 
473 /**
474  * Main task dispatcher for the buffered socket abstraction.
475  * @todo Fix the polling...
476  */
run()477 int BufferedSocket::run() {
478     dcdebug("BufferedSocket::run() start %p\n", (void*)this);
479     setThreadName("BufferedSocket");
480     while(true) {
481         try {
482             if(!checkEvents()) {
483                 break;
484             }
485             if(state == RUNNING) {
486                 checkSocket();
487             }
488         } catch(const Exception& e) {
489             fail(e.getError());
490         }
491     }
492     dcdebug("BufferedSocket::run() end %p\n", (void*)this);
493     delete this;
494     return 0;
495 }
496 
fail(const string & aError)497 void BufferedSocket::fail(const string& aError) {
498     if(sock.get()) {
499         sock->disconnect();
500     }
501 
502     if(state == RUNNING) {
503         state = FAILED;
504         fire(BufferedSocketListener::Failed(), aError);
505     }
506 }
507 
shutdown()508 void BufferedSocket::shutdown() {
509     Lock l(cs);
510     disconnecting = true;
511     addTask(SHUTDOWN, 0);
512 }
513 
addTask(Tasks task,TaskData * data)514 void BufferedSocket::addTask(Tasks task, TaskData* data) {
515     dcassert(task == DISCONNECT || task == SHUTDOWN || task == UPDATED || sock.get());
516     tasks.push_back(make_pair(task, unique_ptr<TaskData>(data))); taskSem.signal();
517 }
518 
519 } // namespace dcpp
520