1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20 
21 /**
22  * SocketListener.cpp
23  *
24  * Berkeley Socket-based ListenerService implementation.
25  */
26 
27 #include "internal.h"
28 #include "exceptions.h"
29 #include "ServiceProvider.h"
30 #include "SPConfig.h"
31 #include "remoting/impl/SocketListener.h"
32 
33 #include <errno.h>
34 #include <stack>
35 #include <sstream>
36 #include <boost/lexical_cast.hpp>
37 #include <xercesc/sax/SAXException.hpp>
38 #include <xercesc/util/XMLUniDefs.hpp>
39 #include <xercesc/util/OutOfMemoryException.hpp>
40 
41 #include <xmltooling/util/NDC.h>
42 #include <xmltooling/util/XMLHelper.h>
43 
44 #ifndef WIN32
45 # include <netinet/in.h>
46 #endif
47 
48 #ifdef HAVE_UNISTD_H
49 # include <unistd.h>
50 #endif
51 
52 using namespace shibsp;
53 using namespace xmltooling;
54 using namespace std;
55 
56 using xercesc::DOMElement;
57 using boost::lexical_cast;
58 
59 namespace shibsp {
60 
61     // Manages the pool of connections
62     class SocketPool
63     {
64     public:
SocketPool(Category & log,const SocketListener * listener)65         SocketPool(Category& log, const SocketListener* listener)
66             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
67         ~SocketPool();
68         SocketListener::ShibSocket get(bool newSocket=false);
69         void put(SocketListener::ShibSocket s);
70 
71     private:
72         SocketListener::ShibSocket connect();
73 
74         Category& m_log;
75         const SocketListener* m_listener;
76         boost::scoped_ptr<Mutex> m_lock;
77         stack<SocketListener::ShibSocket> m_pool;
78     };
79 
80     // Worker threads in server
81     class ServerThread {
82     public:
83         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
84         ~ServerThread();
85         void run();
86         int job();  // Return -1 on error, 1 for closed, 0 for success
87 
88     private:
89         SocketListener::ShibSocket m_sock;
90         Thread* m_child;
91         SocketListener* m_listener;
92         string m_id;
93         char m_buf[16384];
94     };
95 }
96 
connect()97 SocketListener::ShibSocket SocketPool::connect()
98 {
99 #ifdef _DEBUG
100     NDC ndc("connect");
101 #endif
102 
103     m_log.debug("trying to connect to listener");
104 
105     SocketListener::ShibSocket sock;
106     if (!m_listener->create(sock)) {
107         m_log.error("cannot create socket");
108         throw ListenerException("Cannot create socket");
109     }
110 
111     bool connected = false;
112     int num_tries = 3;
113 
114     for (int i = num_tries-1; i >= 0; i--) {
115         if (m_listener->connect(sock)) {
116             connected = true;
117             break;
118         }
119 
120         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
121 
122         if (i) {
123 #ifdef WIN32
124             Sleep(2000*(num_tries-i));
125 #else
126             sleep(2*(num_tries-i));
127 #endif
128         }
129     }
130 
131     if (!connected) {
132         m_log.crit("socket server unavailable, failing");
133         m_listener->close(sock);
134         throw ListenerException("Cannot connect to shibd process, a site administrator should be notified that this web server has malfunctioned.");
135     }
136 
137     m_log.debug("socket (%u) connected successfully", sock);
138     return sock;
139 }
140 
~SocketPool()141 SocketPool::~SocketPool()
142 {
143     while (!m_pool.empty()) {
144 #ifdef WIN32
145         closesocket(m_pool.top());
146 #else
147         ::close(m_pool.top());
148 #endif
149         m_pool.pop();
150     }
151 }
152 
get(bool newSocket)153 SocketListener::ShibSocket SocketPool::get(bool newSocket)
154 {
155     if (newSocket)
156         return connect();
157 
158     m_lock->lock();
159     if (m_pool.empty()) {
160         m_lock->unlock();
161         return connect();
162     }
163     SocketListener::ShibSocket ret=m_pool.top();
164     m_pool.pop();
165     m_lock->unlock();
166     return ret;
167 }
168 
put(SocketListener::ShibSocket s)169 void SocketPool::put(SocketListener::ShibSocket s)
170 {
171     Lock lock(m_lock);
172     m_pool.push(s);
173 }
174 
SocketListener(const DOMElement * e)175 SocketListener::SocketListener(const DOMElement* e)
176     : m_catchAll(false), log(&Category::getInstance(SHIBSP_LOGCAT ".Listener")),
177         m_shutdown(nullptr), m_stackSize(0), m_socket((ShibSocket)0)
178 {
179     // Are we a client?
180     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
181         m_socketpool.reset(new SocketPool(*log,this));
182     }
183     // Are we a server?
184     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
185         m_child_lock.reset(Mutex::create());
186         m_child_wait.reset(CondWait::create());
187 
188         static const XMLCh stackSize[] = UNICODE_LITERAL_9(s,t,a,c,k,S,i,z,e);
189         m_stackSize = XMLHelper::getAttrInt(e, 0, stackSize) * 1024;
190     }
191 }
192 
~SocketListener()193 SocketListener::~SocketListener()
194 {
195 }
196 
init(bool force)197 bool SocketListener::init(bool force)
198 {
199 #ifdef _DEBUG
200     NDC ndc("init");
201 #endif
202     log->info("listener service starting");
203 
204     ServiceProvider* sp = SPConfig::getConfig().getServiceProvider();
205     sp->lock();
206     const PropertySet* props = sp->getPropertySet("OutOfProcess");
207     if (props) {
208         pair<bool,bool> flag = props->getBool("catchAll");
209         m_catchAll = flag.first && flag.second;
210     }
211     sp->unlock();
212 
213     if (!create(m_socket)) {
214         log->crit("failed to create socket");
215         return false;
216     }
217     if (!bind(m_socket, force)) {
218         this->close(m_socket);
219         log->crit("failed to bind to socket.");
220         return false;
221     }
222 
223     return true;
224 }
225 
run(bool * shutdown)226 bool SocketListener::run(bool* shutdown)
227 {
228 #ifdef _DEBUG
229     NDC ndc("run");
230 #endif
231     // Save flag to monitor for shutdown request.
232     m_shutdown = shutdown;
233     unsigned long count = 0;
234 
235     while (!*m_shutdown) {
236         fd_set readfds;
237         FD_ZERO(&readfds);
238         FD_SET(m_socket, &readfds);
239         struct timeval tv = { 0, 0 };
240         tv.tv_sec = 5;
241 
242         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
243 #ifdef WIN32
244             case SOCKET_ERROR:
245 #else
246             case -1:
247 #endif
248                 if (errno == EINTR) continue;
249                 log_error();
250                 log->error("select() on main listener socket failed");
251                 *m_shutdown = true;
252                 break;
253 
254             case 0:
255                 continue;
256 
257             default:
258             {
259                 // Accept the connection.
260                 SocketListener::ShibSocket newsock;
261                 if (!accept(m_socket, newsock)) {
262                     log->crit("failed to accept incoming socket connection");
263                     continue;
264                 }
265 
266                 // We throw away the result because the children manage themselves...
267                 try {
268                     new ServerThread(newsock, this, ++count);
269                 }
270                 catch (exception& ex) {
271                     log->crit("exception starting new server thread to service incoming request: %s", ex.what());
272                 }
273                 catch (...) {
274                     log->crit("unknown error starting new server thread to service incoming request");
275                     if (!m_catchAll)
276                         *m_shutdown = true;
277                 }
278             }
279         }
280     }
281     log->info("listener service shutting down");
282 
283     // Wait for all children to exit.
284     m_child_lock->lock();
285     while (!m_children.empty())
286         m_child_wait->wait(m_child_lock.get());
287     m_child_lock->unlock();
288 
289     return true;
290 }
291 
term()292 void SocketListener::term()
293 {
294     this->close(m_socket);
295     m_socket=(ShibSocket)0;
296 }
297 
send(const DDF & in)298 DDF SocketListener::send(const DDF& in)
299 {
300 #ifdef _DEBUG
301     NDC ndc("send");
302 #endif
303 
304     log->debug("sending message (%s)", in.name() ? in.name() : "unnamed");
305 
306     // Serialize data for transmission.
307     ostringstream os;
308     os << in;
309     string ostr(os.str());
310 
311     // Loop on the RPC in case we lost contact the first time through
312 #ifdef WIN32
313     u_long len;
314 #else
315     uint32_t len;
316 #endif
317     int retry = 1;
318     SocketListener::ShibSocket sock;
319     while (retry >= 0) {
320         // On second time in, just get a new socket.
321         sock = m_socketpool->get(retry == 0);
322 
323         int outlen = ostr.length();
324         len = htonl(outlen);
325         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
326             log_error();
327             this->close(sock);
328             if (retry)
329                 retry--;
330             else
331                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
332         }
333         else {
334             // SUCCESS.
335             retry = -1;
336         }
337     }
338 
339     log->debug("send completed, reading response message");
340 
341     // Read the message.
342     while (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
343     	if (errno == EINTR) continue;	// Apparently this happens when a signal interrupts the blocking call.
344         log->error("error reading size of output message");
345         this->close(sock);
346         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
347     }
348     len = ntohl(len);
349 
350     char buf[16384];
351     int size_read;
352     stringstream is;
353     while (len) {
354     	size_read = recv(sock, buf, sizeof(buf));
355     	if (size_read > 0) {
356             is.write(buf, size_read);
357             len -= size_read;
358     	}
359     	else if (errno != EINTR) {
360     		break;
361     	}
362     }
363 
364     if (len) {
365         log->error("error reading output message from socket");
366         this->close(sock);
367         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
368     }
369 
370     m_socketpool->put(sock);
371 
372     // Unmarshall data.
373     DDF out;
374     is >> out;
375 
376     // Check for exception to unmarshall and throw, otherwise return.
377     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
378         // Reconstitute exception object.
379         DDFJanitor jout(out);
380         XMLToolingException* except=nullptr;
381         try {
382             except=XMLToolingException::fromString(out.string());
383             log->error("remoted message returned an error: %s", except->what());
384         }
385         catch (const XMLToolingException& e) {
386             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
387             log->error("XML was: %s", out.string());
388             throw ListenerException("Remote call failed with an unparsable exception.");
389         }
390 
391         boost::scoped_ptr<XMLToolingException> wrapper(except);
392         wrapper->raise();
393     }
394 
395     return out;
396 }
397 
log_error(const char * fn) const398 bool SocketListener::log_error(const char* fn) const
399 {
400     if (!fn)
401         fn = "unknown";
402 #ifdef WIN32
403     int rc=WSAGetLastError();
404     if (rc == WSAECONNRESET) {
405         log->debug("socket connection reset");
406         return false;
407     }
408 #else
409     int rc=errno;
410 #endif
411     const char *msg;
412 #ifdef HAVE_STRERROR_R
413     char buf[256];
414 #ifdef STRERROR_R_CHAR_P
415     msg = strerror_r(rc,buf,sizeof(buf));
416 #else
417     msg = strerror_r(rc,buf,sizeof(buf)) ? "<translation failed>" : buf;
418 #endif
419 #else
420     msg=strerror(rc);
421 #endif
422     log->error("failed socket call (%s), result (%d): %s", fn, rc, isprint(*msg) ? msg : "no message");
423     return false;
424 }
425 
426 // actual function run in listener on server threads
server_thread_fn(void * arg)427 void* server_thread_fn(void* arg)
428 {
429     ServerThread* child = (ServerThread*)arg;
430 
431 #ifndef WIN32
432     // First, let's block all signals
433     Thread::mask_all_signals();
434 #endif
435 
436     // Run the child until it exits.
437     child->run();
438 
439     // Now we can clean up and exit the thread.
440     delete child;
441     return nullptr;
442 }
443 
ServerThread(SocketListener::ShibSocket & s,SocketListener * listener,unsigned long id)444 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
445     : m_sock(s), m_child(nullptr), m_listener(listener)
446 {
447 
448     m_id = string("[") + lexical_cast<string>(id) + "]";
449 
450     // Create the child thread
451     m_child = Thread::create(server_thread_fn, (void*)this, m_listener->m_stackSize);
452     m_child->detach();
453 }
454 
~ServerThread()455 ServerThread::~ServerThread()
456 {
457     // Then lock the children map, remove this socket/thread, signal waiters, and return
458     m_listener->m_child_lock->lock();
459     m_listener->m_children.erase(m_sock);
460     m_listener->m_child_lock->unlock();
461     m_listener->m_child_wait->signal();
462 
463     delete m_child;
464 }
465 
run()466 void ServerThread::run()
467 {
468     NDC ndc(m_id);
469 
470     // Before starting up, make sure we fully "own" this socket.
471     m_listener->m_child_lock->lock();
472     while (m_listener->m_children.find(m_sock) != m_listener->m_children.end())
473         m_listener->m_child_wait->wait(m_listener->m_child_lock.get());
474     m_listener->m_children[m_sock] = m_child;
475     m_listener->m_child_lock->unlock();
476 
477     int result;
478     fd_set readfds;
479     struct timeval tv = { 0, 0 };
480 
481     while(!*(m_listener->m_shutdown)) {
482         FD_ZERO(&readfds);
483         FD_SET(m_sock, &readfds);
484         tv.tv_sec = 1;
485 
486         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
487 #ifdef WIN32
488         case SOCKET_ERROR:
489 #else
490         case -1:
491 #endif
492             if (errno == EINTR) continue;
493             m_listener->log_error();
494             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
495             return;
496 
497         case 0:
498             break;
499 
500         default:
501             result = job();
502             if (result) {
503                 if (result < 0) {
504                     m_listener->log_error();
505                     m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
506                 }
507                 m_listener->close(m_sock);
508                 return;
509             }
510         }
511     }
512 }
513 
job()514 int ServerThread::job()
515 {
516     Category& log = Category::getInstance(SHIBSP_LOGCAT ".Listener");
517 
518     bool incomingError = true;  // set false once incoming message is received
519     ostringstream sink;
520 #ifdef WIN32
521     u_long len;
522 #else
523     uint32_t len;
524 #endif
525 
526     try {
527         // Read the message.
528         int readlength = m_listener->recv(m_sock,(char*)&len,sizeof(len));
529         if (readlength == 0) {
530             log.info("detected socket closure, shutting down worker thread");
531             return 1;
532         }
533         else if (readlength != sizeof(len)) {
534             log.error("error reading size of input message");
535             return -1;
536         }
537         len = ntohl(len);
538 
539         int size_read;
540         stringstream is;
541         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
542             is.write(m_buf, size_read);
543             len -= size_read;
544         }
545 
546         if (len) {
547             log.error("error reading input message from socket");
548             return -1;
549         }
550 
551         // Unmarshall the message.
552         DDF in;
553         DDFJanitor jin(in);
554         is >> in;
555 
556         string appid;
557         const char* aid = in["application_id"].string();
558         if (aid)
559             appid = string("[") + aid + "]";
560         NDC ndc(appid);
561 
562         log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
563 
564         incomingError = false;
565 
566         // Dispatch the message.
567         m_listener->receive(in, sink);
568     }
569     catch (const xercesc::DOMException& e) {
570         auto_ptr_char temp(e.getMessage());
571         if (incomingError)
572             log.error("error processing incoming message: %s", temp.get() ? temp.get() : "no message");
573         XMLParserException ex(string("DOM error: ") + (temp.get() ? temp.get() : "no message"));
574         DDF out=DDF("exception").string(ex.toString().c_str());
575         DDFJanitor jout(out);
576         sink << out;
577     }
578     catch (const xercesc::SAXException& e) {
579         auto_ptr_char temp(e.getMessage());
580         if (incomingError)
581             log.error("error processing incoming message: %s", temp.get() ? temp.get() : "no message");
582         XMLParserException ex(string("SAX error: ") + (temp.get() ? temp.get() : "no message"));
583         DDF out=DDF("exception").string(ex.toString().c_str());
584         DDFJanitor jout(out);
585         sink << out;
586     }
587     catch (const xercesc::XMLException& e) {
588         auto_ptr_char temp(e.getMessage());
589         if (incomingError)
590             log.error("error processing incoming message: %s", temp.get() ? temp.get() : "no message");
591         XMLParserException ex(string("Xerces error: ") + (temp.get() ? temp.get() : "no message"));
592         DDF out=DDF("exception").string(ex.toString().c_str());
593         DDFJanitor jout(out);
594         sink << out;
595     }
596     catch (const xercesc::OutOfMemoryException& e) {
597         auto_ptr_char temp(e.getMessage());
598         if (incomingError)
599             log.error("error processing incoming message: %s", temp.get() ? temp.get() : "no message");
600         XMLParserException ex(string("Out of memory error: ") + (temp.get() ? temp.get() : "no message"));
601         DDF out=DDF("exception").string(ex.toString().c_str());
602         DDFJanitor jout(out);
603         sink << out;
604     }
605     catch (const XMLToolingException& e) {
606         if (incomingError)
607             log.error("error processing incoming message: %s", e.what());
608         DDF out=DDF("exception").string(e.toString().c_str());
609         DDFJanitor jout(out);
610         sink << out;
611     }
612     catch (const exception& e) {
613         if (incomingError)
614             log.error("error processing incoming message: %s", e.what());
615         ListenerException ex(e.what());
616         DDF out=DDF("exception").string(ex.toString().c_str());
617         DDFJanitor jout(out);
618         sink << out;
619     }
620     catch (...) {
621         if (incomingError)
622             log.error("unexpected error processing incoming message");
623         if (!m_listener->m_catchAll)
624             throw;
625         ListenerException ex("An unexpected error occurred while processing an incoming message.");
626         DDF out=DDF("exception").string(ex.toString().c_str());
627         DDFJanitor jout(out);
628         sink << out;
629     }
630 
631     // Return whatever's available.
632     string response(sink.str());
633     int outlen = response.length();
634     len = htonl(outlen);
635     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
636         log.error("error sending output message size");
637         return -1;
638     }
639     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
640         log.error("error sending output message");
641         return -1;
642     }
643 
644     return 0;
645 }
646