1 /*
2 * Copyright (C) 2001-2012 Jacek Sieka, arnetheduck on gmail point com
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation; either version 2 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, write to the Free Software
16 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17 */
18
19 #include "stdinc.h"
20
21 #include "BufferedSocket.h"
22
23 #include "TimerManager.h"
24 #include "SettingsManager.h"
25
26 #include "Streams.h"
27 #include "SSLSocket.h"
28 #include "CryptoManager.h"
29 #include "ZUtils.h"
30
31 #include "ThrottleManager.h"
32
33 namespace dcpp {
34
// Timeout passed to every Socket wait call in this file (wait(), waitConnected(),
// waitAccepted()); the unit is whatever Socket expects - presumably milliseconds, TODO confirm.
// While RUNNING, the worker polls the task queue without blocking between socket
// waits (see checkEvents()/run()), so this interval also determines how often
// queued tasks are noticed. Polling is used for tasks...should be fixed...
#define POLL_TIMEOUT 250
37
BufferedSocket(char aSeparator)38 BufferedSocket::BufferedSocket(char aSeparator) :
39 separator(aSeparator), mode(MODE_LINE), dataBytes(0), rollback(0), state(STARTING),
40 disconnecting(false)
41 {
42 start();
43
44 sockets.inc();
45 }
46
47 Atomic<long,memory_ordering_strong> BufferedSocket::sockets(0);
48
~BufferedSocket()49 BufferedSocket::~BufferedSocket() {
50 sockets.dec();
51 }
52
setMode(Modes aMode,size_t aRollback)53 void BufferedSocket::setMode (Modes aMode, size_t aRollback) {
54 if (mode == aMode) {
55 dcdebug ("WARNING: Re-entering mode %d\n", mode);
56 return;
57 }
58
59 switch (aMode) {
60 case MODE_LINE:
61 rollback = aRollback;
62 break;
63 case MODE_ZPIPE:
64 filterIn = std::unique_ptr<UnZFilter>(new UnZFilter);
65 break;
66 case MODE_DATA:
67 break;
68 }
69 mode = aMode;
70 }
71
setSocket(std::unique_ptr<Socket> s)72 void BufferedSocket::setSocket(std::unique_ptr<Socket> s) {
73 dcassert(!sock.get());
74 if(SETTING(SOCKET_IN_BUFFER) > 0)
75 s->setSocketOpt(SO_RCVBUF, SETTING(SOCKET_IN_BUFFER));
76 if(SETTING(SOCKET_OUT_BUFFER) > 0)
77 s->setSocketOpt(SO_SNDBUF, SETTING(SOCKET_OUT_BUFFER));
78 s->setSocketOpt(SO_REUSEADDR, 1); // NAT traversal
79
80 inbuf.resize(s->getSocketOptInt(SO_RCVBUF));
81
82 sock = move(s);
83 }
84
accept(const Socket & srv,bool secure,bool allowUntrusted)85 void BufferedSocket::accept(const Socket& srv, bool secure, bool allowUntrusted) {
86 dcdebug("BufferedSocket::accept() %p\n", (void*)this);
87
88 std::unique_ptr<Socket> s(secure ? CryptoManager::getInstance()->getServerSocket(allowUntrusted) : new Socket);
89
90 s->accept(srv);
91
92 setSocket(move(s));
93
94 Lock l(cs);
95 addTask(ACCEPTED, 0);
96 }
97
connect(const string & aAddress,uint16_t aPort,bool secure,bool allowUntrusted,bool proxy)98 void BufferedSocket::connect(const string& aAddress, uint16_t aPort, bool secure, bool allowUntrusted, bool proxy) {
99 connect(aAddress, aPort, 0, NAT_NONE, secure, allowUntrusted, proxy);
100 }
101
connect(const string & aAddress,uint16_t aPort,uint16_t localPort,NatRoles natRole,bool secure,bool allowUntrusted,bool proxy)102 void BufferedSocket::connect(const string& aAddress, uint16_t aPort, uint16_t localPort, NatRoles natRole, bool secure, bool allowUntrusted, bool proxy) {
103 dcdebug("BufferedSocket::connect() %p\n", (void*)this);
104 std::unique_ptr<Socket> s(secure ? (natRole == NAT_SERVER ? CryptoManager::getInstance()->getServerSocket(allowUntrusted) : CryptoManager::getInstance()->getClientSocket(allowUntrusted)) : new Socket);
105
106 s->create();
107 setSocket(move(s));
108 sock->bind(localPort, SETTING(BIND_IFACE)? sock->getIfaceI4(SETTING(BIND_IFACE_NAME)).c_str() : SETTING(BIND_ADDRESS));
109
110 Lock l(cs);
111 addTask(CONNECT, new ConnectInfo(aAddress, aPort, localPort, natRole, proxy && (SETTING(OUTGOING_CONNECTIONS) == SettingsManager::OUTGOING_SOCKS5)));
112 }
113
// Overall deadline for an outgoing connect attempt (added to GET_TICK(), presumably
// milliseconds - TODO confirm); also passed as the timeout to Socket::socksConnect().
#define LONG_TIMEOUT 30000
// Pause passed to Thread::sleep() between NAT-traversal connect retries.
#define SHORT_TIMEOUT 1000
threadConnect(const string & aAddr,uint16_t aPort,uint16_t localPort,NatRoles natRole,bool proxy)116 void BufferedSocket::threadConnect(const string& aAddr, uint16_t aPort, uint16_t localPort, NatRoles natRole, bool proxy) {
117 dcassert(state == STARTING);
118
119 dcdebug("threadConnect %s:%d/%d\n", aAddr.c_str(), (int)localPort, (int)aPort);
120 fire(BufferedSocketListener::Connecting());
121
122 const uint64_t endTime = GET_TICK() + LONG_TIMEOUT;
123 state = RUNNING;
124
125 while (GET_TICK() < endTime) {
126 dcdebug("threadConnect attempt to addr \"%s\"\n", aAddr.c_str());
127 try {
128 if(proxy) {
129 sock->socksConnect(aAddr, aPort, LONG_TIMEOUT);
130 } else {
131 sock->connect(aAddr, aPort);
132 }
133
134 bool connSucceeded;
135 while(!(connSucceeded = sock->waitConnected(POLL_TIMEOUT)) && endTime >= GET_TICK()) {
136 if(disconnecting) return;
137 }
138
139 if (connSucceeded) {
140 fire(BufferedSocketListener::Connected());
141 return;
142 }
143 }
144 catch (const SSLSocketException&) {
145 throw;
146 } catch (const SocketException&) {
147 if (natRole == NAT_NONE)
148 throw;
149 Thread::sleep(SHORT_TIMEOUT);
150 }
151 }
152
153 throw SocketException(_("Connection timeout"));
154 }
155
threadAccept()156 void BufferedSocket::threadAccept() {
157 dcassert(state == STARTING);
158
159 dcdebug("threadAccept\n");
160
161 state = RUNNING;
162
163 uint64_t startTime = GET_TICK();
164 while(!sock->waitAccepted(POLL_TIMEOUT)) {
165 if(disconnecting)
166 return;
167
168 if((startTime + 30000) < GET_TICK()) {
169 throw SocketException(_("Connection timeout"));
170 }
171 }
172 }
173
threadRead()174 void BufferedSocket::threadRead() {
175 if(state != RUNNING)
176 return;
177
178 int left = (mode == MODE_DATA) ? ThrottleManager::getInstance()->read(sock.get(), &inbuf[0], (int)inbuf.size()) : sock->read(&inbuf[0], (int)inbuf.size());
179 if(left == -1) {
180 // EWOULDBLOCK, no data received...
181 return;
182 } else if(left == 0) {
183 // This socket has been closed...
184 throw SocketException(_("Connection closed"));
185 }
186 string::size_type pos = 0;
187 // always uncompressed data
188 string l;
189 int bufpos = 0, total = left;
190
191 while (left > 0) {
192 switch (mode) {
193 case MODE_ZPIPE: {
194 const int BUF_SIZE = 1024;
195 // Special to autodetect nmdc connections...
196 string::size_type pos = 0;
197 boost::scoped_array<char> buffer(new char[BUF_SIZE]);
198 l = line;
199 // decompress all input data and store in l.
200 while (left) {
201 size_t in = BUF_SIZE;
202 size_t used = left;
203 bool ret = (*filterIn) (&inbuf[0] + total - left, used, &buffer[0], in);
204 left -= used;
205 l.append (&buffer[0], in);
206 // if the stream ends before the data runs out, keep remainder of data in inbuf
207 if (!ret) {
208 bufpos = total-left;
209 setMode (MODE_LINE, rollback);
210 break;
211 }
212 }
213 // process all lines
214 while ((pos = l.find(separator)) != string::npos) {
215 if(pos > 0) // check empty (only pipe) command and don't waste cpu with it ;o)
216 fire(BufferedSocketListener::Line(), l.substr(0, pos));
217 l.erase (0, pos + 1 /* separator char */);
218 }
219 // store remainder
220 line = l;
221
222 break;
223 }
224 case MODE_LINE:
225 // Special to autodetect nmdc connections...
226 if(separator == 0) {
227 if(inbuf[0] == '$') {
228 separator = '|';
229 } else {
230 separator = '\n';
231 }
232 }
233 l = line + string ((char*)&inbuf[bufpos], left);
234 while ((pos = l.find(separator)) != string::npos) {
235 if(pos > 0) // check empty (only pipe) command and don't waste cpu with it ;o)
236 fire(BufferedSocketListener::Line(), l.substr(0, pos));
237 l.erase (0, pos + 1 /* separator char */);
238 if (l.length() < (size_t)left) left = l.length();
239 if (mode != MODE_LINE) {
240 // we changed mode; remainder of l is invalid.
241 l.clear();
242 bufpos = total - left;
243 break;
244 }
245 }
246 if (pos == string::npos)
247 left = 0;
248 line = l;
249 break;
250 case MODE_DATA:
251 while(left > 0) {
252 if(dataBytes == -1) {
253 fire(BufferedSocketListener::Data(), &inbuf[bufpos], left);
254 bufpos += (left - rollback);
255 left = rollback;
256 rollback = 0;
257 } else {
258 int high = (int)min(dataBytes, (int64_t)left);
259 fire(BufferedSocketListener::Data(), &inbuf[bufpos], high);
260 bufpos += high;
261 left -= high;
262
263 dataBytes -= high;
264 if(dataBytes == 0) {
265 mode = MODE_LINE;
266 fire(BufferedSocketListener::ModeChange());
267 }
268 }
269 }
270 break;
271 }
272 }
273
274 if(mode == MODE_LINE && line.size() > static_cast<size_t>(SETTING(MAX_COMMAND_LENGTH))) {
275 throw SocketException(_("Maximum command length exceeded"));
276 }
277 }
278
/**
 * Stream the contents of `file` to the socket (SEND_FILE task).
 *
 * Two buffers are used. The data in writeBuf is written to the socket, and
 * whenever a write would block, readBuf is topped up from the file. When
 * neither is possible, incoming data is serviced via threadRead() while waiting
 * for the socket to become writable.
 *
 * Events fired:
 *  - BytesSent(n, 0) after each file read.
 *  - BytesSent(0, n) after each successful socket write.
 *  - TransmitDone once the file is exhausted and everything has been written.
 *
 * Returns early, without TransmitDone, if not RUNNING or if shutdown() is
 * requested.
 *
 * @param file Source stream; must not be NULL. This function does not delete it.
 */
void BufferedSocket::threadSendFile(InputStream* file) {
	if(state != RUNNING)
		return;

	if(disconnecting)
		return;
	dcassert(file != NULL);
	// Buffers are at least 64 KiB, or the kernel send buffer size if that is larger.
	size_t sockSize = (size_t)sock->getSocketOptInt(SO_SNDBUF);
	size_t bufSize = max(sockSize, (size_t)64*1024);

	ByteVector readBuf(bufSize);
	ByteVector writeBuf(bufSize);

	// Number of valid bytes currently in readBuf.
	size_t readPos = 0;

	bool readDone = false;
	dcdebug("Starting threadSend\n");
	while(!disconnecting) {
		if(!readDone && readBuf.size() > readPos) {
			// Fill read buffer
			size_t bytesRead = readBuf.size() - readPos;
			// NOTE(review): BytesSent below reports bytesRead, not `actual`; presumably
			// InputStream::read() updates its length argument in place - confirm.
			size_t actual = file->read(&readBuf[readPos], bytesRead);

			if(bytesRead > 0) {
				fire(BufferedSocketListener::BytesSent(), bytesRead, 0);
			}

			if(actual == 0) {
				readDone = true;
			} else {
				readPos += actual;
			}
		}

		// Nothing more to read and nothing pending: the transfer is complete.
		if(readDone && readPos == 0) {
			fire(BufferedSocketListener::TransmitDone());
			return;
		}

		// What was just read becomes the chunk to send; the old send buffer is reused for reading.
		readBuf.swap(writeBuf);
		readBuf.resize(bufSize);
		writeBuf.resize(readPos);
		readPos = 0;

		size_t writePos = 0, writeSize = 0;
		int written = 0;

		while(writePos < writeBuf.size()) {
			if(disconnecting)
				return;

			if(written == -1) {
				// workaround for OpenSSL (crashes when previous write failed and now retrying with different writeSize)
				try {
					written = sock->write(&writeBuf[writePos], writeSize);
				} catch(const Exception&) {
					// ...
					// NOTE(review): the exception is swallowed and `written` stays -1, so the
					// loop falls back to reading/waiting and retries; confirm this cannot
					// spin indefinitely on a permanently broken socket.
				}
			} else {
				// NOTE(review): if sockSize were 0 or 1, writeSize would be 0 and no
				// progress could be made - presumably SO_SNDBUF is never that small.
				writeSize = min(sockSize / 2, writeBuf.size() - writePos);
				written = ThrottleManager::getInstance()->write(sock.get(), &writeBuf[writePos], writeSize);
			}

			if(written > 0) {
				writePos += written;

				fire(BufferedSocketListener::BytesSent(), 0, written);

			} else if(written == -1) {
				// The write would block: use the time to read ahead, or wait for writability.
				if(!readDone && readPos < readBuf.size()) {
					// Read a little since we're blocking anyway...
					size_t bytesRead = min(readBuf.size() - readPos, readBuf.size() / 2);
					size_t actual = file->read(&readBuf[readPos], bytesRead);

					if(bytesRead > 0) {
						fire(BufferedSocketListener::BytesSent(), bytesRead, 0);
					}

					if(actual == 0) {
						readDone = true;
					} else {
						readPos += actual;
					}
				} else {
					// Read-ahead buffer is full (or the file is done): wait until the
					// socket is writable, servicing incoming data meanwhile.
					while(!disconnecting) {
						int w = sock->wait(POLL_TIMEOUT, Socket::WAIT_WRITE | Socket::WAIT_READ);
						if(w & Socket::WAIT_READ) {
							threadRead();
						}
						if(w & Socket::WAIT_WRITE) {
							break;
						}
					}
				}
			}
		}
	}
}
377
write(const char * aBuf,size_t aLen)378 void BufferedSocket::write(const char* aBuf, size_t aLen) noexcept {
379 if(!sock.get())
380 return;
381 Lock l(cs);
382 if(writeBuf.empty())
383 addTask(SEND_DATA, 0);
384
385 writeBuf.insert(writeBuf.end(), aBuf, aBuf+aLen);
386 }
387
threadSendData()388 void BufferedSocket::threadSendData() {
389 if(state != RUNNING)
390 return;
391
392 {
393 Lock l(cs);
394 if(writeBuf.empty())
395 return;
396
397 writeBuf.swap(sendBuf);
398 }
399
400 size_t left = sendBuf.size();
401 size_t done = 0;
402 while(left > 0) {
403 if(disconnecting) {
404 return;
405 }
406
407 int w = sock->wait(POLL_TIMEOUT, Socket::WAIT_READ | Socket::WAIT_WRITE);
408
409 if(w & Socket::WAIT_READ) {
410 threadRead();
411 }
412
413 if(w & Socket::WAIT_WRITE) {
414 int n = sock->write(&sendBuf[done], left);
415 if(n > 0) {
416 left -= n;
417 done += n;
418 }
419 }
420 }
421 sendBuf.clear();
422 }
423
checkEvents()424 bool BufferedSocket::checkEvents() {
425 while(state == RUNNING ? taskSem.wait(0) : taskSem.wait()) {
426 pair<Tasks, unique_ptr<TaskData> > p;
427 {
428 Lock l(cs);
429 dcassert(!tasks.empty());
430 p = move(tasks.front());
431 tasks.erase(tasks.begin());
432 }
433
434 if(p.first == SHUTDOWN) {
435 return false;
436 } else if(p.first == UPDATED) {
437 fire(BufferedSocketListener::Updated());
438 continue;
439 }
440
441 if(state == STARTING) {
442 if(p.first == CONNECT) {
443 ConnectInfo* ci = static_cast<ConnectInfo*>(p.second.get());
444 threadConnect(ci->addr, ci->port, ci->localPort, ci->natRole, ci->proxy);
445 } else if(p.first == ACCEPTED) {
446 threadAccept();
447 } else {
448 dcdebug("%d unexpected in STARTING state\n", p.first);
449 }
450 } else if(state == RUNNING) {
451 if(p.first == SEND_DATA) {
452 threadSendData();
453 } else if(p.first == SEND_FILE) {
454 threadSendFile(static_cast<SendFileInfo*>(p.second.get())->stream); break;
455 } else if(p.first == DISCONNECT) {
456 fail(_("Disconnected"));
457 } else {
458 dcdebug("%d unexpected in RUNNING state\n", p.first);
459 }
460 }
461 }
462 return true;
463 }
464
checkSocket()465 void BufferedSocket::checkSocket() {
466 int waitFor = sock->wait(POLL_TIMEOUT, Socket::WAIT_READ);
467
468 if(waitFor & Socket::WAIT_READ) {
469 threadRead();
470 }
471 }
472
473 /**
474 * Main task dispatcher for the buffered socket abstraction.
475 * @todo Fix the polling...
476 */
run()477 int BufferedSocket::run() {
478 dcdebug("BufferedSocket::run() start %p\n", (void*)this);
479 setThreadName("BufferedSocket");
480 while(true) {
481 try {
482 if(!checkEvents()) {
483 break;
484 }
485 if(state == RUNNING) {
486 checkSocket();
487 }
488 } catch(const Exception& e) {
489 fail(e.getError());
490 }
491 }
492 dcdebug("BufferedSocket::run() end %p\n", (void*)this);
493 delete this;
494 return 0;
495 }
496
fail(const string & aError)497 void BufferedSocket::fail(const string& aError) {
498 if(sock.get()) {
499 sock->disconnect();
500 }
501
502 if(state == RUNNING) {
503 state = FAILED;
504 fire(BufferedSocketListener::Failed(), aError);
505 }
506 }
507
shutdown()508 void BufferedSocket::shutdown() {
509 Lock l(cs);
510 disconnecting = true;
511 addTask(SHUTDOWN, 0);
512 }
513
addTask(Tasks task,TaskData * data)514 void BufferedSocket::addTask(Tasks task, TaskData* data) {
515 dcassert(task == DISCONNECT || task == SHUTDOWN || task == UPDATED || sock.get());
516 tasks.push_back(make_pair(task, unique_ptr<TaskData>(data))); taskSem.signal();
517 }
518
519 } // namespace dcpp
520