1 /*
2 * CommunicationsChannel.cpp
3 * Created by Woody Zenfell, III on Mon Sep 01 2003.
4 */
5
6 /*
7 Copyright (c) 2003, Woody Zenfell, III
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in
17 all copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 SOFTWARE.
26 */
27
28 #if !defined(DISABLE_NETWORKING)
29
30 #include "CommunicationsChannel.h"
31
32 #include "AStream.h"
33 #include "MessageInflater.h"
34 #include "MessageHandler.h"
35
36 #include <stdlib.h>
37 #include <iostream> // debugging
38 #include <cerrno>
39 #include "cseries.h"
40 #if defined(WIN32)
41 #include <winsock2.h> // hacky non-cross-platform setting of nonblocking
42 #else
43 #include <fcntl.h> // hacky non-cross-platform setting of nonblocking
44 #endif
45 #include <algorithm>
46
// Limits and polling intervals used by the channel implementation below.
enum
{
    // Largest body length (in bytes) an incoming header may announce; a header
    // claiming more than this causes the channel to disconnect.
    kMaximumMessageLength = 4 * 1024 * 1024,

    // Sleep, in milliseconds, between pump() calls while
    // receive[Specific]Message() is waiting for data
    kSSRPumpInterval = 50,

    // Sleep, in milliseconds, between pump() calls while
    // flushOutgoingMessages() is waiting for the outgoing queue to drain
    kFlushPumpInterval = kSSRPumpInterval
};
58
59 // if you really want to read what this does, scroll down
60 static void MakeTCPsocketNonBlocking(TCPsocket *socket);
61
CommunicationsChannel()62 CommunicationsChannel::CommunicationsChannel()
63 : mConnected(false),
64 mSocket(NULL),
65 mMessageInflater(NULL),
66 mMessageHandler(NULL),
67 mMemento(NULL),
68 mIncomingHeaderPosition(0),
69 mIncomingMessage(NULL),
70 mIncomingMessagePosition(0),
71 mOutgoingHeaderPosition(0),
72 mOutgoingMessagePosition(0)
73 {
74 mTicksAtLastReceive = SDL_GetTicks();
75 mTicksAtLastSend = SDL_GetTicks();
76 }
77
78
79
CommunicationsChannel(TCPsocket inSocket)80 CommunicationsChannel::CommunicationsChannel(TCPsocket inSocket)
81 : mConnected(inSocket != NULL),
82 mSocket(inSocket),
83 mMessageInflater(NULL),
84 mMessageHandler(NULL),
85 mMemento(NULL),
86 mIncomingHeaderPosition(0),
87 mIncomingMessage(NULL),
88 mIncomingMessagePosition(0),
89 mOutgoingHeaderPosition(0),
90 mOutgoingMessagePosition(0)
91 {
92 mTicksAtLastReceive = SDL_GetTicks();
93 mTicksAtLastSend = SDL_GetTicks();
94 }
95
96
97
~CommunicationsChannel()98 CommunicationsChannel::~CommunicationsChannel()
99 {
100 disconnect();
101 }
102
103
104
105 CommunicationsChannel::CommunicationResult
receive_some(TCPsocket inSocket,byte * inBuffer,size_t & ioBufferPosition,size_t inBufferLength)106 CommunicationsChannel::receive_some(TCPsocket inSocket, byte* inBuffer, size_t& ioBufferPosition, size_t inBufferLength)
107 {
108 // std::cout << "Want to receive " << inBufferLength << " bytes; buffer position " << ioBufferPosition << std::endl;
109
110 if (inBufferLength == 0) return kComplete;
111
112 size_t theBytesLeft = inBufferLength - ioBufferPosition;
113
114 if(theBytesLeft > 0)
115 {
116 int theResult = SDLNet_TCP_Recv(inSocket, inBuffer + ioBufferPosition, theBytesLeft);
117
118 // std::cout << " theResult is " << theResult << std::endl;
119
120 // Unfortunately, SDLNet_TCP_Recv() often returns -1 even when there's no error, and I
121 // don't think I have any legitimate way to distinguish this case from a true error condition.
122 #ifdef SANE_RECV_RESULTS
123 if(theResult < 0)
124 {
125 disconnect();
126 return kError;
127 }
128 else
129 {
130 if(theResult > 0)
131 {
132 mTicksAtLastReceive = SDL_GetTicks();
133 }
134
135 ioBufferPosition += theResult;
136 return (ioBufferPosition == inBufferLength) ? kComplete : kIncomplete;
137 }
138 }
139 #else
140 if(theResult == 0)
141 {
142 // For some reason we get 0 back if the connection is lost ...
143 disconnect();
144 return kError;
145 }
146
147 if(theResult < 0)
148 {
149 // Please close your eyes for this part ... we get -1 back from SDL_net,
150 // then peek around behind its back to try to figure out why. YUCK
151 // Hmm surely this is doomed to fail on non-UNIXy systems? sigh ...
152 // Perhaps we should treat 0 and < 0 the same here, and change what it
153 // means to be connected. Maybe we could use the Get Peer function to
154 // detect connected/disconnected.
155
156 // grsmith: we could do that, or we could add another
157 // platform-specific hack
158 #ifdef WIN32
159 if (WSAGetLastError() == WSAEWOULDBLOCK) {
160 theResult = 0;
161 } else {
162 std::cout << "theResult == " << theResult << std::endl;
163 disconnect();
164 return kError;
165 }
166 #else
167 if(errno == EAGAIN)
168 {
169 theResult = 0;
170 }
171 else
172 {
173 std::cout << "theResult == " << theResult << " ; errno == " << errno << " ; strerror() == " << strerror(errno) << " ; SDL_GetError() == " << SDL_GetError() << std::endl;
174
175 disconnect();
176 return kError;
177 }
178 #endif
179 }
180 if(theResult > 0)
181 {
182 mTicksAtLastReceive = SDL_GetTicks();
183 }
184
185 ioBufferPosition += theResult;
186 } // if we actually expect to receive something
187
188 return (ioBufferPosition == inBufferLength) ? kComplete : kIncomplete;
189 #endif // SANE_RECV_RESULTS
190 }
191
192
193
194 CommunicationsChannel::CommunicationResult
send_some(TCPsocket inSocket,byte * inBuffer,size_t & ioBufferPosition,size_t inBufferLength)195 CommunicationsChannel::send_some(TCPsocket inSocket, byte* inBuffer, size_t& ioBufferPosition, size_t inBufferLength)
196 {
197 // std::cout << "Want to send " << inBufferLength << " bytes; buffer position " << ioBufferPosition << std::endl;
198
199 size_t theBytesLeft = inBufferLength - ioBufferPosition;
200
201 int theResult = SDLNet_TCP_Send(inSocket, inBuffer + ioBufferPosition, theBytesLeft);
202
203 // std::cout << " theResult is " << theResult << std::endl;
204
205 if(theResult < 0)
206 {
207 disconnect();
208 return kError;
209 }
210 else
211 {
212 if(theResult > 0)
213 mTicksAtLastSend = SDL_GetTicks();
214
215 ioBufferPosition += theResult;
216 return (ioBufferPosition == inBufferLength) ? kComplete : kIncomplete;
217 }
218 }
219
220
221
222 bool
receiveHeader()223 CommunicationsChannel::receiveHeader()
224 {
225 CommunicationResult theResult =
226 receive_some(mSocket, mIncomingHeader, mIncomingHeaderPosition, kHeaderPackedSize);
227
228 if(theResult == kComplete)
229 {
230 // Finished receiving a header
231 AIStreamBE theHeaderStream(mIncomingHeader, kHeaderPackedSize);
232
233 uint16 theMagic;
234 uint16 theMessageType;
235 uint32 theMessageLength;
236
237 theHeaderStream >> theMagic
238 >> theMessageType
239 >> theMessageLength;
240
241 // Incoming length includes header length
242 theMessageLength -= kHeaderPackedSize;
243
244 if(theMagic != kHeaderMagic || theMessageLength > kMaximumMessageLength)
245 {
246 disconnect();
247 }
248 else
249 {
250 // Successfully received a valid header; switch to receive-message mode
251 mIncomingMessage = new UninflatedMessage(theMessageType, theMessageLength);
252 mIncomingMessagePosition = 0;
253 }
254
255 // We should try to receive more stuff, since we got all we asked for.
256 return true;
257 }
258 else
259 {
260 // We got less than we wanted - no sense in trying for more.
261 return false;
262 }
263 }
264
265
266
267 bool
_receiveMessage()268 CommunicationsChannel::_receiveMessage()
269 {
270 CommunicationResult theResult =
271 receive_some(mSocket, mIncomingMessage->buffer(), mIncomingMessagePosition, mIncomingMessage->length());
272
273 if(theResult == kComplete)
274 {
275 // Received a complete message; inflate (if possible) then enqueue it
276 Message* theMessageToEnqueue = mIncomingMessage;
277
278 if(mMessageInflater != NULL)
279 {
280 theMessageToEnqueue = mMessageInflater->inflate(*mIncomingMessage);
281 delete mIncomingMessage;
282 }
283
284 mIncomingMessages.push_back(theMessageToEnqueue);
285
286 // No longer receiving message body - prepare to receive next header
287 mIncomingMessage = NULL;
288 mIncomingHeaderPosition = 0;
289
290 // We got all we wanted - so we should go again to see if there's more.
291 return true;
292 }
293 else
294 {
295 // Ran out of data to receive, or error - no sense looking for more data
296 return false;
297 }
298 }
299
300
301
302 bool
sendHeader()303 CommunicationsChannel::sendHeader()
304 {
305 CommunicationResult theResult =
306 send_some(mSocket, mOutgoingHeader, mOutgoingHeaderPosition, kHeaderPackedSize);
307
308 if(theResult == kComplete)
309 {
310 // Finished sending a header; switch to sending message now
311 mOutgoingMessagePosition = 0;
312
313 // We should try to send more stuff, since we sent all we asked to.
314 return true;
315 }
316 else
317 {
318 // We sent less than we wanted - no sense in trying for more.
319 return false;
320 }
321 }
322
323
324
325 bool
sendMessage()326 CommunicationsChannel::sendMessage()
327 {
328 UninflatedMessage* theOutgoingMessage = mOutgoingMessages.front();
329
330 CommunicationResult theResult =
331 send_some(mSocket, theOutgoingMessage->buffer(), mOutgoingMessagePosition, theOutgoingMessage->length());
332
333 if(theResult == kComplete)
334 {
335 // Sent a complete message; delete and dequeue it
336 delete theOutgoingMessage;
337 mOutgoingMessages.pop_front();
338
339 // No longer sending message body - prepare to send next header
340 mOutgoingHeaderPosition = 0;
341
342 // We sent all we wanted - so we should go again to see if we can do more.
343 return true;
344 }
345 else
346 {
347 // Could not send it all, or error - no sense trying for more data
348 return false;
349 }
350 }
351
352
353
354 void
pumpReceivingSide()355 CommunicationsChannel::pumpReceivingSide()
356 {
357 bool keepGoing = true;
358 while(keepGoing && mConnected)
359 {
360 if(mIncomingMessage != NULL)
361 {
362 // Already working on receiving message body
363 keepGoing = _receiveMessage();
364 }
365 else
366 {
367 // Not receiving message body - must be receiving message header then
368 keepGoing = receiveHeader();
369 }
370 }
371 }
372
373
374
375 void
pumpSendingSide()376 CommunicationsChannel::pumpSendingSide()
377 {
378 bool keepGoing = true;
379 while(keepGoing && mConnected && !mOutgoingMessages.empty())
380 {
381 if(mOutgoingHeaderPosition == 0)
382 {
383 // Need to fill packed header buffer with packed header
384 // We may end up doing this more than once if for some reason we can't
385 // send any data bytes to TCP ... but that's OK.
386 UninflatedMessage* theMessage = mOutgoingMessages.front();
387 AOStreamBE theHeaderStream(mOutgoingHeader, kHeaderPackedSize);
388 theHeaderStream << (Uint16)kHeaderMagic
389 << theMessage->inflatedType()
390 << (uint32)(theMessage->length() + kHeaderPackedSize);
391 }
392
393 if(mOutgoingHeaderPosition < kHeaderPackedSize)
394 {
395 keepGoing = sendHeader();
396 }
397 else
398 {
399 keepGoing = sendMessage();
400 }
401 }
402 }
403
404
405
406 void
pump()407 CommunicationsChannel::pump()
408 {
409 pumpSendingSide();
410 pumpReceivingSide();
411 }
412
dispatchOneIncomingMessage()413 bool CommunicationsChannel::dispatchOneIncomingMessage()
414 {
415 if (mIncomingMessages.empty()) return false;
416 Message* theMessage = mIncomingMessages.front();
417 if (messageHandler() != NULL) {
418 messageHandler()->handle(theMessage, this);
419 }
420 delete theMessage;
421 mIncomingMessages.pop_front();
422 return true;
423 }
424
425 void
dispatchIncomingMessages()426 CommunicationsChannel::dispatchIncomingMessages()
427 {
428 while (dispatchOneIncomingMessage());
429 }
430
431
432
433 void
enqueueOutgoingMessage(const Message & inMessage)434 CommunicationsChannel::enqueueOutgoingMessage(const Message& inMessage)
435 {
436 if(isConnected())
437 {
438 UninflatedMessage* theUninflatedMessage = inMessage.deflate();
439 mOutgoingMessages.push_back(theUninflatedMessage);
440 }
441 }
442
443 IPaddress
peerAddress() const444 CommunicationsChannel::peerAddress() const
445 {
446 return *(SDLNet_TCP_GetPeerAddress(mSocket));
447 }
448
449 void
connect(const IPaddress & inAddress)450 CommunicationsChannel::connect(const IPaddress& inAddress)
451 {
452 assert(!isConnected());
453
454 mIncomingHeaderPosition = 0;
455 mIncomingMessagePosition = 0;
456 delete mIncomingMessage;
457 mIncomingMessage = 0;
458
459 for(MessageQueue::iterator i = mIncomingMessages.begin(); i != mIncomingMessages.end(); ++i)
460 delete *i;
461
462 mIncomingMessages.clear();
463
464 // Have to copy the address since we get a const, but SDL_net takes a non-const
465 IPaddress theAddress = inAddress;
466 mSocket = SDLNet_TCP_Open(&theAddress);
467
468 if(mSocket != NULL)
469 {
470 mConnected = true;
471
472 mTicksAtLastReceive = SDL_GetTicks();
473 mTicksAtLastSend = SDL_GetTicks();
474
475 MakeTCPsocketNonBlocking(&mSocket);
476 }
477 }
478
479
480
481 void
connect(const std::string & inAddressString,uint16 inPort)482 CommunicationsChannel::connect(const std::string& inAddressString, uint16 inPort)
483 {
484 IPaddress theAddress;
485 // Have to copy the string since we get a const, but SDL_net takes a non-const
486 char* theDuplicateString = strdup(inAddressString.c_str());
487 int theResult = SDLNet_ResolveHost(&theAddress, theDuplicateString, inPort);
488 free(theDuplicateString);
489 if(theResult == 0)
490 {
491 connect(theAddress);
492 }
493 }
494
495
496
497 void
disconnect()498 CommunicationsChannel::disconnect()
499 {
500 if(mSocket != NULL)
501 {
502 SDLNet_TCP_Close(mSocket);
503 mSocket = NULL;
504 mConnected = false;
505 }
506
507 // Discard all data so next connect()ion starts with a clean slate
508 mOutgoingHeaderPosition = 0;
509 mOutgoingMessagePosition = 0;
510
511 for(UninflatedMessageQueue::iterator i = mOutgoingMessages.begin(); i != mOutgoingMessages.end(); ++i)
512 delete *i;
513
514 mOutgoingMessages.clear();
515 }
516
517
518
519 bool
isMessageAvailable()520 CommunicationsChannel::isMessageAvailable()
521 {
522 pump();
523
524 return !mIncomingMessages.empty();
525 }
526
527
528
529 // Call does not return unless (1) times out (NULL); (2) disconnected (NULL); or
530 // (3) some message received (pointer to inflated message object).
531 Message*
receiveMessage(Uint32 inOverallTimeout,Uint32 inInactivityTimeout)532 CommunicationsChannel::receiveMessage(Uint32 inOverallTimeout, Uint32 inInactivityTimeout)
533 {
534 // Here we give a backstop for our inactivity timeout
535 Uint32 theTicksAtStart = SDL_GetTicks();
536
537 Uint32 theDeadline = SDL_GetTicks() + inOverallTimeout;
538
539 pump();
540
541 while(SDL_GetTicks() - std::max(mTicksAtLastReceive, theTicksAtStart) < inInactivityTimeout
542 && SDL_GetTicks() < theDeadline
543 && isConnected()
544 && mIncomingMessages.empty())
545 {
546 SDL_Delay(kSSRPumpInterval);
547 pump();
548 }
549
550 Message* theMessage = NULL;
551
552 if(!mIncomingMessages.empty())
553 {
554 theMessage = mIncomingMessages.front();
555 mIncomingMessages.pop_front();
556 }
557
558 return theMessage;
559 }
560
561
562
563 // As above, but if messages of type other than inType are received, they're handled
564 // normally (so might want to install conservative Handler first)
565 Message*
receiveSpecificMessage(MessageTypeID inType,Uint32 inOverallTimeout,Uint32 inInactivityTimeout)566 CommunicationsChannel::receiveSpecificMessage(
567 MessageTypeID inType,
568 Uint32 inOverallTimeout,
569 Uint32 inInactivityTimeout)
570 {
571 Message* theMessage = NULL;
572 Uint32 theDeadline = SDL_GetTicks() + inOverallTimeout;
573
574 while(SDL_GetTicks() < theDeadline)
575 {
576 theMessage = receiveMessage(theDeadline - SDL_GetTicks(), inInactivityTimeout);
577
578 if(theMessage)
579 {
580 if(theMessage->type() == inType)
581 // Got our message
582 break;
583 else
584 {
585 // Got some other message - handle it and destroy it
586 if(messageHandler() != NULL)
587 {
588 messageHandler()->handle(theMessage, this);
589 }
590 delete theMessage;
591 theMessage = NULL;
592 }
593 }
594 else
595 // Other routine timed out or got disconnected
596 break;
597 }
598
599 return theMessage;
600 }
601
602
603
604 void
flushOutgoingMessages(bool shouldDispatchIncomingMessages,Uint32 inOverallTimeout,Uint32 inInactivityTimeout)605 CommunicationsChannel::flushOutgoingMessages(bool shouldDispatchIncomingMessages,
606 Uint32 inOverallTimeout,
607 Uint32 inInactivityTimeout)
608 {
609 Uint32 theDeadline = SDL_GetTicks() + inOverallTimeout;
610 Uint32 theTicksAtStart = SDL_GetTicks();
611
612 while(isConnected()
613 && !mOutgoingMessages.empty()
614 && SDL_GetTicks() < theDeadline
615 && SDL_GetTicks() - std::max(mTicksAtLastSend, theTicksAtStart) < inInactivityTimeout)
616 {
617 SDL_Delay(kFlushPumpInterval);
618 pump();
619 if(shouldDispatchIncomingMessages)
620 dispatchIncomingMessages();
621 }
622 }
623
624
multipleFlushOutgoingMessages(std::vector<CommunicationsChannel * > & channels,bool shouldDispatchIncomingMessages,Uint32 inOverallTimeout,Uint32 inInactivityTimeout)625 void CommunicationsChannel::multipleFlushOutgoingMessages(
626 std::vector<CommunicationsChannel *>& channels,
627 bool shouldDispatchIncomingMessages,
628 Uint32 inOverallTimeout,
629 Uint32 inInactivityTimeout)
630 {
631 Uint32 theDeadline = SDL_GetTicks() + inOverallTimeout;
632 Uint32 theTicksAtStart = SDL_GetTicks();
633
634 bool someoneIsStillActive = true;
635
636 while (SDL_GetTicks() < theDeadline && someoneIsStillActive)
637 {
638 someoneIsStillActive = false;
639
640 SDL_Delay(kFlushPumpInterval);
641
642 for (std::vector<CommunicationsChannel*>::iterator it = channels.begin(); it != channels.end(); it++)
643 {
644 if (!(*it)->mOutgoingMessages.empty() && SDL_GetTicks() - std::max((*it)->mTicksAtLastSend, theTicksAtStart) < inInactivityTimeout)
645 {
646 someoneIsStillActive = true;
647 }
648
649 (*it)->pump();
650 if (shouldDispatchIncomingMessages)
651 (*it)->dispatchIncomingMessages();
652 }
653
654 }
655
656 }
657
658
CommunicationsChannelFactory(uint16 inPort)659 CommunicationsChannelFactory::CommunicationsChannelFactory(uint16 inPort)
660 {
661 IPaddress theAddress;
662 theAddress.host = INADDR_ANY;
663 theAddress.port = SDL_SwapBE16(inPort);
664
665 mSocket = SDLNet_TCP_Open(&theAddress);
666 }
667
668
669
670 CommunicationsChannel*
newIncomingConnection()671 CommunicationsChannelFactory::newIncomingConnection()
672 {
673 CommunicationsChannel* theNewChannel = NULL;
674
675 if(isFunctional())
676 {
677 SDLNet_SocketSet theSocketSet = SDLNet_AllocSocketSet(1);
678 SDLNet_TCP_AddSocket(theSocketSet, mSocket);
679 if(SDLNet_CheckSockets(theSocketSet, 0) > 0) {
680 // Yee-haw! There's an incoming connection request.
681 TCPsocket theNewSocket = SDLNet_TCP_Accept(mSocket);
682 theNewChannel = new CommunicationsChannel(theNewSocket);
683 MakeTCPsocketNonBlocking(&theNewSocket);
684
685 }
686 SDLNet_FreeSocketSet(theSocketSet);
687 }
688
689 return theNewChannel;
690 }
691
692
693
~CommunicationsChannelFactory()694 CommunicationsChannelFactory::~CommunicationsChannelFactory()
695 {
696 SDLNet_TCP_Close(mSocket);
697 }
698
MakeTCPsocketNonBlocking(TCPsocket * socket)699 void MakeTCPsocketNonBlocking(TCPsocket *socket) {
700 // SET NONBLOCKING MODE
701 // XXX: this depends on intimate carnal knowledge of the SDL_net struct _UDPsocket
702 // if it changes that structure, we are hosed.
703
704 int fd = ((int *) (*socket))[1];
705 #if defined(WIN32)
706 u_long val = 1;
707 ioctlsocket(fd, FIONBIO, &val);
708 #else
709
710 fcntl(fd, F_SETFL, O_NONBLOCK);
711
712 #endif
713 }
714
715 #endif // !defined(DISABLE_NETWORKING)
716
717