1 /**
2 * Licensed to the University Corporation for Advanced Internet
3 * Development, Inc. (UCAID) under one or more contributor license
4 * agreements. See the NOTICE file distributed with this work for
5 * additional information regarding copyright ownership.
6 *
7 * UCAID licenses this file to you under the Apache License,
8 * Version 2.0 (the "License"); you may not use this file except
9 * in compliance with the License. You may obtain a copy of the
10 * License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing,
15 * software distributed under the License is distributed on an
16 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17 * either express or implied. See the License for the specific
18 * language governing permissions and limitations under the License.
19 */
20
21 /**
22 * SocketListener.cpp
23 *
24 * Berkeley Socket-based ListenerService implementation.
25 */
26
#include "internal.h"
#include "exceptions.h"
#include "ServiceProvider.h"
#include "SPConfig.h"
#include "remoting/impl/SocketListener.h"

#include <climits>
#include <errno.h>
#include <stack>
#include <sstream>
#include <boost/lexical_cast.hpp>
#include <xercesc/sax/SAXException.hpp>
#include <xercesc/util/XMLUniDefs.hpp>
#include <xercesc/util/OutOfMemoryException.hpp>

#include <xmltooling/util/NDC.h>
#include <xmltooling/util/XMLHelper.h>

#ifndef WIN32
# include <netinet/in.h>
#endif

#ifdef HAVE_UNISTD_H
# include <unistd.h>
#endif
51
52 using namespace shibsp;
53 using namespace xmltooling;
54 using namespace std;
55
56 using xercesc::DOMElement;
57 using boost::lexical_cast;
58
59 namespace shibsp {
60
61 // Manages the pool of connections
62 class SocketPool
63 {
64 public:
SocketPool(Category & log,const SocketListener * listener)65 SocketPool(Category& log, const SocketListener* listener)
66 : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
67 ~SocketPool();
68 SocketListener::ShibSocket get(bool newSocket=false);
69 void put(SocketListener::ShibSocket s);
70
71 private:
72 SocketListener::ShibSocket connect();
73
74 Category& m_log;
75 const SocketListener* m_listener;
76 boost::scoped_ptr<Mutex> m_lock;
77 stack<SocketListener::ShibSocket> m_pool;
78 };
79
80 // Worker threads in server
81 class ServerThread {
82 public:
83 ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
84 ~ServerThread();
85 void run();
86 int job(); // Return -1 on error, 1 for closed, 0 for success
87
88 private:
89 SocketListener::ShibSocket m_sock;
90 Thread* m_child;
91 SocketListener* m_listener;
92 string m_id;
93 char m_buf[16384];
94 };
95 }
96
connect()97 SocketListener::ShibSocket SocketPool::connect()
98 {
99 #ifdef _DEBUG
100 NDC ndc("connect");
101 #endif
102
103 m_log.debug("trying to connect to listener");
104
105 SocketListener::ShibSocket sock;
106 if (!m_listener->create(sock)) {
107 m_log.error("cannot create socket");
108 throw ListenerException("Cannot create socket");
109 }
110
111 bool connected = false;
112 int num_tries = 3;
113
114 for (int i = num_tries-1; i >= 0; i--) {
115 if (m_listener->connect(sock)) {
116 connected = true;
117 break;
118 }
119
120 m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
121
122 if (i) {
123 #ifdef WIN32
124 Sleep(2000*(num_tries-i));
125 #else
126 sleep(2*(num_tries-i));
127 #endif
128 }
129 }
130
131 if (!connected) {
132 m_log.crit("socket server unavailable, failing");
133 m_listener->close(sock);
134 throw ListenerException("Cannot connect to shibd process, a site administrator should be notified that this web server has malfunctioned.");
135 }
136
137 m_log.debug("socket (%u) connected successfully", sock);
138 return sock;
139 }
140
~SocketPool()141 SocketPool::~SocketPool()
142 {
143 while (!m_pool.empty()) {
144 #ifdef WIN32
145 closesocket(m_pool.top());
146 #else
147 ::close(m_pool.top());
148 #endif
149 m_pool.pop();
150 }
151 }
152
get(bool newSocket)153 SocketListener::ShibSocket SocketPool::get(bool newSocket)
154 {
155 if (newSocket)
156 return connect();
157
158 m_lock->lock();
159 if (m_pool.empty()) {
160 m_lock->unlock();
161 return connect();
162 }
163 SocketListener::ShibSocket ret=m_pool.top();
164 m_pool.pop();
165 m_lock->unlock();
166 return ret;
167 }
168
put(SocketListener::ShibSocket s)169 void SocketPool::put(SocketListener::ShibSocket s)
170 {
171 Lock lock(m_lock);
172 m_pool.push(s);
173 }
174
SocketListener(const DOMElement * e)175 SocketListener::SocketListener(const DOMElement* e)
176 : m_catchAll(false), log(&Category::getInstance(SHIBSP_LOGCAT ".Listener")),
177 m_shutdown(nullptr), m_stackSize(0), m_socket((ShibSocket)0)
178 {
179 // Are we a client?
180 if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
181 m_socketpool.reset(new SocketPool(*log,this));
182 }
183 // Are we a server?
184 if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
185 m_child_lock.reset(Mutex::create());
186 m_child_wait.reset(CondWait::create());
187
188 static const XMLCh stackSize[] = UNICODE_LITERAL_9(s,t,a,c,k,S,i,z,e);
189 m_stackSize = XMLHelper::getAttrInt(e, 0, stackSize) * 1024;
190 }
191 }
192
~SocketListener()193 SocketListener::~SocketListener()
194 {
195 }
196
init(bool force)197 bool SocketListener::init(bool force)
198 {
199 #ifdef _DEBUG
200 NDC ndc("init");
201 #endif
202 log->info("listener service starting");
203
204 ServiceProvider* sp = SPConfig::getConfig().getServiceProvider();
205 sp->lock();
206 const PropertySet* props = sp->getPropertySet("OutOfProcess");
207 if (props) {
208 pair<bool,bool> flag = props->getBool("catchAll");
209 m_catchAll = flag.first && flag.second;
210 }
211 sp->unlock();
212
213 if (!create(m_socket)) {
214 log->crit("failed to create socket");
215 return false;
216 }
217 if (!bind(m_socket, force)) {
218 this->close(m_socket);
219 log->crit("failed to bind to socket.");
220 return false;
221 }
222
223 return true;
224 }
225
run(bool * shutdown)226 bool SocketListener::run(bool* shutdown)
227 {
228 #ifdef _DEBUG
229 NDC ndc("run");
230 #endif
231 // Save flag to monitor for shutdown request.
232 m_shutdown = shutdown;
233 unsigned long count = 0;
234
235 while (!*m_shutdown) {
236 fd_set readfds;
237 FD_ZERO(&readfds);
238 FD_SET(m_socket, &readfds);
239 struct timeval tv = { 0, 0 };
240 tv.tv_sec = 5;
241
242 switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
243 #ifdef WIN32
244 case SOCKET_ERROR:
245 #else
246 case -1:
247 #endif
248 if (errno == EINTR) continue;
249 log_error();
250 log->error("select() on main listener socket failed");
251 *m_shutdown = true;
252 break;
253
254 case 0:
255 continue;
256
257 default:
258 {
259 // Accept the connection.
260 SocketListener::ShibSocket newsock;
261 if (!accept(m_socket, newsock)) {
262 log->crit("failed to accept incoming socket connection");
263 continue;
264 }
265
266 // We throw away the result because the children manage themselves...
267 try {
268 new ServerThread(newsock, this, ++count);
269 }
270 catch (exception& ex) {
271 log->crit("exception starting new server thread to service incoming request: %s", ex.what());
272 }
273 catch (...) {
274 log->crit("unknown error starting new server thread to service incoming request");
275 if (!m_catchAll)
276 *m_shutdown = true;
277 }
278 }
279 }
280 }
281 log->info("listener service shutting down");
282
283 // Wait for all children to exit.
284 m_child_lock->lock();
285 while (!m_children.empty())
286 m_child_wait->wait(m_child_lock.get());
287 m_child_lock->unlock();
288
289 return true;
290 }
291
term()292 void SocketListener::term()
293 {
294 this->close(m_socket);
295 m_socket=(ShibSocket)0;
296 }
297
/**
 * Client-side RPC: sends a message to the listener and returns its response.
 *
 * The request is a length prefix followed by the serialized DDF. A failed send is retried
 * once on a brand-new connection. A response of type "exception" is reconstituted and
 * rethrown.
 *
 * @param in message to send
 * @return the unmarshalled response
 * @throws ListenerException on transport failures or an unparsable remote exception
 * @throws XMLToolingException (or subclass) raised by the remote side
 */
DDF SocketListener::send(const DDF& in)
{
#ifdef _DEBUG
    NDC ndc("send");
#endif

    // in.name() may be null. Resolve a printable name once so neither the log calls nor
    // params() below are ever handed a null pointer.
    const char* msgname = in.name() ? in.name() : "unnamed";
    log->debug("sending message (%s)", msgname);

    // Serialize data for transmission.
    ostringstream os;
    os << in;
    const string ostr(os.str());
    const int outlen = static_cast<int>(ostr.length());

#ifdef WIN32
    u_long len;
#else
    uint32_t len;
#endif

    // Loop on the RPC in case we lost contact the first time through (e.g. a stale
    // pooled connection).
    int retry = 1;
    SocketListener::ShibSocket sock;
    while (retry >= 0) {
        // On second time in, just get a new socket.
        sock = m_socketpool->get(retry == 0);

        len = htonl(outlen);
        if (send(sock, (char*)&len, sizeof(len)) != sizeof(len) || send(sock, ostr.c_str(), outlen) != outlen) {
            log_error();
            this->close(sock);
            if (retry)
                retry--;
            else
                throw ListenerException("Failure sending remoted message ($1).", params(1, msgname));
        }
        else {
            // SUCCESS.
            retry = -1;
        }
    }

    log->debug("send completed, reading response message");

    // Read the response's length prefix. Only a call that actually failed (-1) with EINTR is
    // retried. A return of 0 (peer closed) does not set errno, and retrying on a stale EINTR
    // value would spin forever.
    for (;;) {
        const int hdrlen = recv(sock, (char*)&len, sizeof(len));
        if (hdrlen == static_cast<int>(sizeof(len)))
            break;
        if (hdrlen < 0 && errno == EINTR)
            continue;
        log->error("error reading size of output message");
        this->close(sock);
        throw ListenerException("Failure receiving response to remoted message ($1).", params(1, msgname));
    }
    len = ntohl(len);

    // Read exactly len bytes. Never request more than remains, so that len cannot wrap
    // around and no bytes beyond this response are consumed from the (pooled) connection.
    char buf[16384];
    stringstream is;
    while (len) {
        const int want = static_cast<int>(len < sizeof(buf) ? len : sizeof(buf));
        const int size_read = recv(sock, buf, want);
        if (size_read > 0) {
            is.write(buf, size_read);
            len -= size_read;
        }
        else if (size_read < 0 && errno == EINTR) {
            continue;
        }
        else {
            break;
        }
    }

    if (len) {
        log->error("error reading output message from socket");
        this->close(sock);
        throw ListenerException("Failure receiving response to remoted message ($1).", params(1, msgname));
    }

    // The exchange completed, so the connection can be reused.
    m_socketpool->put(sock);

    // Unmarshall data.
    DDF out;
    is >> out;

    // Check for exception to unmarshall and throw, otherwise return.
    if (out.isstring() && out.name() && !strcmp(out.name(), "exception")) {
        // Reconstitute exception object.
        DDFJanitor jout(out);
        XMLToolingException* except = nullptr;
        try {
            except = XMLToolingException::fromString(out.string());
            log->error("remoted message returned an error: %s", except->what());
        }
        catch (const XMLToolingException& e) {
            log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
            log->error("XML was: %s", out.string());
            throw ListenerException("Remote call failed with an unparsable exception.");
        }

        // raise() stays outside the try so the rethrown remote exception is not caught above.
        boost::scoped_ptr<XMLToolingException> wrapper(except);
        wrapper->raise();
    }

    return out;
}
397
log_error(const char * fn) const398 bool SocketListener::log_error(const char* fn) const
399 {
400 if (!fn)
401 fn = "unknown";
402 #ifdef WIN32
403 int rc=WSAGetLastError();
404 if (rc == WSAECONNRESET) {
405 log->debug("socket connection reset");
406 return false;
407 }
408 #else
409 int rc=errno;
410 #endif
411 const char *msg;
412 #ifdef HAVE_STRERROR_R
413 char buf[256];
414 #ifdef STRERROR_R_CHAR_P
415 msg = strerror_r(rc,buf,sizeof(buf));
416 #else
417 msg = strerror_r(rc,buf,sizeof(buf)) ? "<translation failed>" : buf;
418 #endif
419 #else
420 msg=strerror(rc);
421 #endif
422 log->error("failed socket call (%s), result (%d): %s", fn, rc, isprint(*msg) ? msg : "no message");
423 return false;
424 }
425
426 // actual function run in listener on server threads
server_thread_fn(void * arg)427 void* server_thread_fn(void* arg)
428 {
429 ServerThread* child = (ServerThread*)arg;
430
431 #ifndef WIN32
432 // First, let's block all signals
433 Thread::mask_all_signals();
434 #endif
435
436 // Run the child until it exits.
437 child->run();
438
439 // Now we can clean up and exit the thread.
440 delete child;
441 return nullptr;
442 }
443
ServerThread(SocketListener::ShibSocket & s,SocketListener * listener,unsigned long id)444 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
445 : m_sock(s), m_child(nullptr), m_listener(listener)
446 {
447
448 m_id = string("[") + lexical_cast<string>(id) + "]";
449
450 // Create the child thread
451 m_child = Thread::create(server_thread_fn, (void*)this, m_listener->m_stackSize);
452 m_child->detach();
453 }
454
~ServerThread()455 ServerThread::~ServerThread()
456 {
457 // Then lock the children map, remove this socket/thread, signal waiters, and return
458 m_listener->m_child_lock->lock();
459 m_listener->m_children.erase(m_sock);
460 m_listener->m_child_lock->unlock();
461 m_listener->m_child_wait->signal();
462
463 delete m_child;
464 }
465
run()466 void ServerThread::run()
467 {
468 NDC ndc(m_id);
469
470 // Before starting up, make sure we fully "own" this socket.
471 m_listener->m_child_lock->lock();
472 while (m_listener->m_children.find(m_sock) != m_listener->m_children.end())
473 m_listener->m_child_wait->wait(m_listener->m_child_lock.get());
474 m_listener->m_children[m_sock] = m_child;
475 m_listener->m_child_lock->unlock();
476
477 int result;
478 fd_set readfds;
479 struct timeval tv = { 0, 0 };
480
481 while(!*(m_listener->m_shutdown)) {
482 FD_ZERO(&readfds);
483 FD_SET(m_sock, &readfds);
484 tv.tv_sec = 1;
485
486 switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
487 #ifdef WIN32
488 case SOCKET_ERROR:
489 #else
490 case -1:
491 #endif
492 if (errno == EINTR) continue;
493 m_listener->log_error();
494 m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
495 return;
496
497 case 0:
498 break;
499
500 default:
501 result = job();
502 if (result) {
503 if (result < 0) {
504 m_listener->log_error();
505 m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
506 }
507 m_listener->close(m_sock);
508 return;
509 }
510 }
511 }
512 }
513
job()514 int ServerThread::job()
515 {
516 Category& log = Category::getInstance(SHIBSP_LOGCAT ".Listener");
517
518 bool incomingError = true; // set false once incoming message is received
519 ostringstream sink;
520 #ifdef WIN32
521 u_long len;
522 #else
523 uint32_t len;
524 #endif
525
526 try {
527 // Read the message.
528 int readlength = m_listener->recv(m_sock,(char*)&len,sizeof(len));
529 if (readlength == 0) {
530 log.info("detected socket closure, shutting down worker thread");
531 return 1;
532 }
533 else if (readlength != sizeof(len)) {
534 log.error("error reading size of input message");
535 return -1;
536 }
537 len = ntohl(len);
538
539 int size_read;
540 stringstream is;
541 while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
542 is.write(m_buf, size_read);
543 len -= size_read;
544 }
545
546 if (len) {
547 log.error("error reading input message from socket");
548 return -1;
549 }
550
551 // Unmarshall the message.
552 DDF in;
553 DDFJanitor jin(in);
554 is >> in;
555
556 string appid;
557 const char* aid = in["application_id"].string();
558 if (aid)
559 appid = string("[") + aid + "]";
560 NDC ndc(appid);
561
562 log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
563
564 incomingError = false;
565
566 // Dispatch the message.
567 m_listener->receive(in, sink);
568 }
569 catch (const xercesc::DOMException& e) {
570 auto_ptr_char temp(e.getMessage());
571 if (incomingError)
572 log.error("error processing incoming message: %s", temp.get() ? temp.get() : "no message");
573 XMLParserException ex(string("DOM error: ") + (temp.get() ? temp.get() : "no message"));
574 DDF out=DDF("exception").string(ex.toString().c_str());
575 DDFJanitor jout(out);
576 sink << out;
577 }
578 catch (const xercesc::SAXException& e) {
579 auto_ptr_char temp(e.getMessage());
580 if (incomingError)
581 log.error("error processing incoming message: %s", temp.get() ? temp.get() : "no message");
582 XMLParserException ex(string("SAX error: ") + (temp.get() ? temp.get() : "no message"));
583 DDF out=DDF("exception").string(ex.toString().c_str());
584 DDFJanitor jout(out);
585 sink << out;
586 }
587 catch (const xercesc::XMLException& e) {
588 auto_ptr_char temp(e.getMessage());
589 if (incomingError)
590 log.error("error processing incoming message: %s", temp.get() ? temp.get() : "no message");
591 XMLParserException ex(string("Xerces error: ") + (temp.get() ? temp.get() : "no message"));
592 DDF out=DDF("exception").string(ex.toString().c_str());
593 DDFJanitor jout(out);
594 sink << out;
595 }
596 catch (const xercesc::OutOfMemoryException& e) {
597 auto_ptr_char temp(e.getMessage());
598 if (incomingError)
599 log.error("error processing incoming message: %s", temp.get() ? temp.get() : "no message");
600 XMLParserException ex(string("Out of memory error: ") + (temp.get() ? temp.get() : "no message"));
601 DDF out=DDF("exception").string(ex.toString().c_str());
602 DDFJanitor jout(out);
603 sink << out;
604 }
605 catch (const XMLToolingException& e) {
606 if (incomingError)
607 log.error("error processing incoming message: %s", e.what());
608 DDF out=DDF("exception").string(e.toString().c_str());
609 DDFJanitor jout(out);
610 sink << out;
611 }
612 catch (const exception& e) {
613 if (incomingError)
614 log.error("error processing incoming message: %s", e.what());
615 ListenerException ex(e.what());
616 DDF out=DDF("exception").string(ex.toString().c_str());
617 DDFJanitor jout(out);
618 sink << out;
619 }
620 catch (...) {
621 if (incomingError)
622 log.error("unexpected error processing incoming message");
623 if (!m_listener->m_catchAll)
624 throw;
625 ListenerException ex("An unexpected error occurred while processing an incoming message.");
626 DDF out=DDF("exception").string(ex.toString().c_str());
627 DDFJanitor jout(out);
628 sink << out;
629 }
630
631 // Return whatever's available.
632 string response(sink.str());
633 int outlen = response.length();
634 len = htonl(outlen);
635 if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
636 log.error("error sending output message size");
637 return -1;
638 }
639 if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
640 log.error("error sending output message");
641 return -1;
642 }
643
644 return 0;
645 }
646