1 /*
2  * Copyright (C) 2017-2018 Daniel Nicoletti <dantti12@gmail.com>
3  *
4  * This library is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * This library is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with this library; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
17  */
#include "tcpserverbalancer.h"

#include "server.h"
#include "cwsgiengine.h"
#include "tcpserver.h"
#include "tcpsslserver.h"

#include <QFile>
#include <QLoggingCategory>

#include <QSslKey>

#include <iostream>

#ifdef Q_OS_LINUX
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <fcntl.h>
#include <unistd.h>
#endif
38 
39 
// Logging category for this file. Its default threshold is QtWarningMsg, so
// the qCDebug() output below is hidden unless the category is enabled
// (qCCritical() messages are shown).
Q_LOGGING_CATEGORY(CWSGI_BALANCER, "wsgi.tcp_server_balancer", QtWarningMsg)

using namespace Cutelyst;

#ifdef Q_OS_LINUX
// Creates and binds a TCP socket on address:port, optionally with
// SO_REUSEPORT and optionally already listening; defined further down in
// this file and used by listen() and createServer().
int listenReuse(const QHostAddress &address, int listenQueue, quint16 port, bool reusePort, bool startListening);
#endif
47 
TcpServerBalancer(Server * wsgi)48 TcpServerBalancer::TcpServerBalancer(Server *wsgi) : QTcpServer(wsgi)
49   , m_wsgi(wsgi)
50 {
51 }
52 
~TcpServerBalancer()53 TcpServerBalancer::~TcpServerBalancer()
54 {
55 #ifndef QT_NO_SSL
56     delete m_sslConfiguration;
57 #endif // QT_NO_SSL
58 }
59 
listen(const QString & line,Protocol * protocol,bool secure)60 bool TcpServerBalancer::listen(const QString &line, Protocol *protocol, bool secure)
61 {
62     m_protocol = protocol;
63 
64     int commaPos = line.indexOf(QLatin1Char(','));
65     const QString addressPortString = line.mid(0, commaPos);
66 
67     QString addressString;
68     int closeBracketPos = addressPortString.indexOf(QLatin1Char(']'));
69     if (closeBracketPos != -1) {
70         if (!line.startsWith(QLatin1Char('['))) {
71             std::cerr << "Failed to parse address: " << qPrintable(addressPortString) << std::endl;
72             return false;
73         }
74         addressString = addressPortString.mid(1, closeBracketPos - 1);
75     } else {
76         addressString = addressPortString.section(QLatin1Char(':'), 0, -2);
77     }
78     const QString portString = addressPortString.section(QLatin1Char(':'), -1);
79 
80     QHostAddress address;
81     if (addressString.isEmpty()) {
82         address = QHostAddress(QHostAddress::Any);
83     } else {
84         address.setAddress(addressString);
85     }
86 
87     bool ok;
88     quint16 port = portString.toUInt(&ok);
89     if (!ok || (port < 1 || port > 35554)) {
90         port = 80;
91     }
92 
93 #ifndef QT_NO_SSL
94     if (secure) {
95         if (commaPos == -1) {
96             std::cerr << "No SSL certificate specified" << std::endl;
97             return false;
98         }
99 
100         const QString sslString = line.mid(commaPos + 1);
101         const QString certPath = sslString.section(QLatin1Char(','), 0, 0);
102         QFile certFile(certPath);
103         if (!certFile.open(QFile::ReadOnly)) {
104             std::cerr << "Failed to open SSL certificate" << qPrintable(certPath)
105                       << qPrintable(certFile.errorString()) << std::endl;
106             return false;
107         }
108         QSslCertificate cert(&certFile);
109         if (cert.isNull()) {
110             std::cerr << "Failed to parse SSL certificate" << std::endl;
111             return false;
112         }
113 
114         const QString keyPath = sslString.section(QLatin1Char(','), 1, 1);
115         QFile keyFile(keyPath);
116         if (!keyFile.open(QFile::ReadOnly)) {
117             std::cerr << "Failed to open SSL private key" << qPrintable(keyPath)
118                       << qPrintable(keyFile.errorString()) << std::endl;
119             return false;
120         }
121 
122         QSsl::KeyAlgorithm algorithm = QSsl::Rsa;
123         const QString keyAlgorithm = sslString.section(QLatin1Char(','), 2, 2);
124         if (!keyAlgorithm.isEmpty()) {
125             if (keyAlgorithm.compare(QLatin1String("rsa"), Qt::CaseInsensitive) == 0) {
126                 algorithm = QSsl::Rsa;
127             } else if (keyAlgorithm.compare(QLatin1String("ec"), Qt::CaseInsensitive) == 0) {
128                 algorithm = QSsl::Ec;
129             } else {
130                 std::cerr << "Failed to select SSL Key Algorithm" << qPrintable(keyAlgorithm) << std::endl;
131                 return false;
132             }
133         }
134 
135         QSslKey key(&keyFile, algorithm);
136         if (key.isNull()) {
137             std::cerr << "Failed to parse SSL private key" << std::endl;
138             return false;
139         }
140 
141         m_sslConfiguration = new QSslConfiguration;
142         m_sslConfiguration->setLocalCertificate(cert);
143         m_sslConfiguration->setPrivateKey(key);
144         m_sslConfiguration->setPeerVerifyMode(QSslSocket::VerifyNone); // prevent asking for client certificate
145         if (m_wsgi->httpsH2()) {
146             m_sslConfiguration->setAllowedNextProtocols({ QByteArrayLiteral("h2"), QSslConfiguration::NextProtocolHttp1_1});
147         }
148     }
149 #endif // QT_NO_SSL
150 
151     m_address = address;
152     m_port = port;
153 
154 #ifdef Q_OS_LINUX
155     int socket = listenReuse(address, m_wsgi->listenQueue(), port, m_wsgi->reusePort(), !m_wsgi->reusePort());
156     if (socket > 0 && setSocketDescriptor(socket)) {
157         pauseAccepting();
158     } else {
159         std::cerr << "Failed to listen on TCP: " << qPrintable(line)
160                   << " : " << qPrintable(errorString()) << std::endl;
161         return false;
162     }
163 #else
164     bool ret = QTcpServer::listen(address, port);
165     if (ret) {
166         pauseAccepting();
167     } else {
168         std::cerr << "Failed to listen on TCP: " << qPrintable(line)
169                   << " : " << qPrintable(errorString()) << std::endl;
170         return false;
171     }
172 #endif
173 
174     m_serverName = serverAddress().toString() + QLatin1Char(':') + QString::number(port);
175     return true;
176 }
177 
178 #ifdef Q_OS_LINUX
179 // UnixWare 7 redefines socket -> _socket
qt_safe_socket(int domain,int type,int protocol,int flags=0)180 static inline int qt_safe_socket(int domain, int type, int protocol, int flags = 0)
181 {
182     Q_ASSERT((flags & ~O_NONBLOCK) == 0);
183 
184     int fd;
185 #ifdef QT_THREADSAFE_CLOEXEC
186     int newtype = type | SOCK_CLOEXEC;
187     if (flags & O_NONBLOCK)
188         newtype |= SOCK_NONBLOCK;
189     fd = ::socket(domain, newtype, protocol);
190     return fd;
191 #else
192     fd = ::socket(domain, type, protocol);
193     if (fd == -1)
194         return -1;
195 
196     ::fcntl(fd, F_SETFD, FD_CLOEXEC);
197 
198     // set non-block too?
199     if (flags & O_NONBLOCK)
200         ::fcntl(fd, F_SETFL, ::fcntl(fd, F_GETFL) | O_NONBLOCK);
201 
202     return fd;
203 #endif
204 }
205 
createNewSocket(QAbstractSocket::NetworkLayerProtocol & socketProtocol)206 int createNewSocket(QAbstractSocket::NetworkLayerProtocol &socketProtocol)
207 {
208     int protocol = 0;
209 
210     int domain = (socketProtocol == QAbstractSocket::IPv6Protocol
211                   || socketProtocol == QAbstractSocket::AnyIPProtocol) ? AF_INET6 : AF_INET;
212     int type = SOCK_STREAM;
213 
214     int socket = qt_safe_socket(domain, type, protocol, O_NONBLOCK);
215     if (socket < 0 && socketProtocol == QAbstractSocket::AnyIPProtocol && errno == EAFNOSUPPORT) {
216         domain = AF_INET;
217         socket = qt_safe_socket(domain, type, protocol, O_NONBLOCK);
218         socketProtocol = QAbstractSocket::IPv4Protocol;
219     }
220 
221     if (socket < 0) {
222         int ecopy = errno;
223         switch (ecopy) {
224         case EPROTONOSUPPORT:
225         case EAFNOSUPPORT:
226         case EINVAL:
227             qCDebug(CWSGI_BALANCER) << "setError(QAbstractSocket::UnsupportedSocketOperationError, ProtocolUnsupportedErrorString)";
228             break;
229         case ENFILE:
230         case EMFILE:
231         case ENOBUFS:
232         case ENOMEM:
233             qCDebug(CWSGI_BALANCER) << "setError(QAbstractSocket::SocketResourceError, ResourceErrorString)";
234             break;
235         case EACCES:
236             qCDebug(CWSGI_BALANCER) << "setError(QAbstractSocket::SocketAccessError, AccessErrorString)";
237             break;
238         default:
239             break;
240         }
241 
242 #if defined (QNATIVESOCKETENGINE_DEBUG)
243         qCDebug(CWSGI_BALANCER, "QNativeSocketEnginePrivate::createNewSocket(%d, %d) == false (%s)",
244                 socketType, socketProtocol,
245                 strerror(ecopy));
246 #endif
247 
248         return false;
249     }
250 
251 #if defined (QNATIVESOCKETENGINE_DEBUG)
252     qCDebug(CWSGI_BALANCER, "QNativeSocketEnginePrivate::createNewSocket(%d, %d) == true",
253             socketType, socketProtocol);
254 #endif
255 
256     return socket;
257 }
258 
// Storage large enough for either an IPv4 or an IPv6 socket address:
// `a` is the generic view handed to the socket API, `a4`/`a6` are the
// family-specific views filled in by setPortAndAddress().
union qt_sockaddr {
    sockaddr a;
    sockaddr_in a4;
    sockaddr_in6 a6;
};

// Length type passed to SetSALen::set() below.
#  define QT_SOCKLEN_T int
// Indirection for the system bind() call, used by nativeBind().
#define QT_SOCKET_BIND          ::bind
267 
// SetSALen::set(addr, len) stores len in the structure's sa_len / sin6_len
// member on platforms whose socket address structures have one (e.g. the
// BSDs), and does nothing where they do not (e.g. Linux).
namespace {
namespace SetSALen {
    // Viable only when T has a member named sa_len: otherwise the comma
    // expression (&T::sa_len, true) is ill-formed and SFINAE drops the overload.
    template <typename T> void set(T *sa, typename std::enable_if<(&T::sa_len, true), QT_SOCKLEN_T>::type len)
    { sa->sa_len = len; }
    // Same technique for types with a sin6_len member.
    template <typename T> void set(T *sin6, typename std::enable_if<(&T::sin6_len, true), QT_SOCKLEN_T>::type len)
    { sin6->sin6_len = len; }
    // Fallback: an ellipsis conversion ranks below the overloads above, so
    // this no-op is chosen only when neither member exists.
    template <typename T> void set(T *, ...) {}
}
}
277 
setPortAndAddress(quint16 port,const QHostAddress & address,QAbstractSocket::NetworkLayerProtocol socketProtocol,qt_sockaddr * aa,int * sockAddrSize)278 void setPortAndAddress(quint16 port, const QHostAddress &address, QAbstractSocket::NetworkLayerProtocol socketProtocol, qt_sockaddr *aa, int *sockAddrSize)
279 {
280     if (address.protocol() == QAbstractSocket::IPv6Protocol
281         || address.protocol() == QAbstractSocket::AnyIPProtocol
282         || socketProtocol == QAbstractSocket::IPv6Protocol
283         || socketProtocol == QAbstractSocket::AnyIPProtocol) {
284         memset(&aa->a6, 0, sizeof(sockaddr_in6));
285         aa->a6.sin6_family = AF_INET6;
286 //#if QT_CONFIG(networkinterface)
287 //            aa->a6.sin6_scope_id = scopeIdFromString(address.scopeId());
288 //#endif
289         aa->a6.sin6_port = htons(port);
290         Q_IPV6ADDR tmp = address.toIPv6Address();
291         memcpy(&aa->a6.sin6_addr, &tmp, sizeof(tmp));
292         *sockAddrSize = sizeof(sockaddr_in6);
293         SetSALen::set(&aa->a, sizeof(sockaddr_in6));
294     } else {
295         memset(&aa->a, 0, sizeof(sockaddr_in));
296         aa->a4.sin_family = AF_INET;
297         aa->a4.sin_port = htons(port);
298         aa->a4.sin_addr.s_addr = htonl(address.toIPv4Address());
299         *sockAddrSize = sizeof(sockaddr_in);
300         SetSALen::set(&aa->a, sizeof(sockaddr_in));
301     }
302 }
303 
nativeBind(int socketDescriptor,const QHostAddress & address,quint16 port)304 bool nativeBind(int socketDescriptor, const QHostAddress &address, quint16 port)
305 {
306     qt_sockaddr aa;
307     int sockAddrSize;
308     setPortAndAddress(port, address, address.protocol(), &aa, &sockAddrSize);
309 
310 #ifdef IPV6_V6ONLY
311     if (aa.a.sa_family == AF_INET6) {
312         int ipv6only = 0;
313         if (address.protocol() == QAbstractSocket::IPv6Protocol)
314             ipv6only = 1;
315         //default value of this socket option varies depending on unix variant (or system configuration on BSD), so always set it explicitly
316         ::setsockopt(socketDescriptor, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&ipv6only, sizeof(ipv6only) );
317     }
318 #endif
319 
320     int bindResult = ::bind(socketDescriptor, &aa.a, sockAddrSize);
321     if (bindResult < 0 && errno == EAFNOSUPPORT && address.protocol() == QAbstractSocket::AnyIPProtocol) {
322         // retry with v4
323         aa.a4.sin_family = AF_INET;
324         aa.a4.sin_port = htons(port);
325         aa.a4.sin_addr.s_addr = htonl(address.toIPv4Address());
326         sockAddrSize = sizeof(aa.a4);
327         bindResult = QT_SOCKET_BIND(socketDescriptor, &aa.a, sockAddrSize);
328     }
329 
330     if (bindResult < 0) {
331 #if defined (QNATIVESOCKETENGINE_DEBUG)
332         int ecopy = errno;
333 #endif
334 //        switch(errno) {
335 //        case EADDRINUSE:
336 //            setError(QAbstractSocket::AddressInUseError, AddressInuseErrorString);
337 //            break;
338 //        case EACCES:
339 //            setError(QAbstractSocket::SocketAccessError, AddressProtectedErrorString);
340 //            break;
341 //        case EINVAL:
342 //            setError(QAbstractSocket::UnsupportedSocketOperationError, OperationUnsupportedErrorString);
343 //            break;
344 //        case EADDRNOTAVAIL:
345 //            setError(QAbstractSocket::SocketAddressNotAvailableError, AddressNotAvailableErrorString);
346 //            break;
347 //        default:
348 //            break;
349 //        }
350 
351 #if defined (QNATIVESOCKETENGINE_DEBUG)
352         qCDebug(CWSGI_BALANCER, "QNativeSocketEnginePrivate::nativeBind(%s, %i) == false (%s)",
353                 address.toString().toLatin1().constData(), port, strerror(ecopy));
354 #endif
355 
356         return false;
357     }
358 
359 #if defined (QNATIVESOCKETENGINE_DEBUG)
360     qCDebug(CWSGI_BALANCER, "QNativeSocketEnginePrivate::nativeBind(%s, %i) == true",
361             address.toString().toLatin1().constData(), port);
362 #endif
363 //    socketState = QAbstractSocket::BoundState;
364     return true;
365 }
366 
listenReuse(const QHostAddress & address,int listenQueue,quint16 port,bool reusePort,bool startListening)367 int listenReuse(const QHostAddress &address, int listenQueue, quint16 port, bool reusePort, bool startListening)
368 {
369     QAbstractSocket::NetworkLayerProtocol proto = address.protocol();
370 
371     int socket = createNewSocket(proto);
372     if (socket < 0) {
373         qCCritical(CWSGI_BALANCER) << "Failed to create new socket";
374         return -1;
375     }
376 
377     int optval = 1;
378     // SO_REUSEADDR is set by default on QTcpServer and allows to bind again
379     // without having to wait all previous connections to close
380     if (::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval))) {
381         qCCritical(CWSGI_BALANCER) << "Failed to set SO_REUSEADDR on socket" << socket;
382         return -1;
383     }
384 
385     if (reusePort) {
386         if (::setsockopt(socket, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval))) {
387             qCCritical(CWSGI_BALANCER) << "Failed to set SO_REUSEPORT on socket" << socket;
388             return -1;
389         }
390     }
391 
392     if (!nativeBind(socket, address, port)) {
393         qCCritical(CWSGI_BALANCER) << "Failed to bind to socket" << socket;
394         return -1;
395     }
396 
397     if (startListening && ::listen(socket, listenQueue) < 0) {
398         qCCritical(CWSGI_BALANCER) << "Failed to listen to socket" << socket;
399         return -1;
400     }
401 
402     return socket;
403 }
404 #endif // Q_OS_LINUX
405 
// Enables or disables balancer mode. The flag is read by createServer(), so
// it only affects servers created afterwards. When enabled, this object
// accepts connections itself and hands them to the registered servers in
// turn (incomingConnection()); when disabled, each server accepts on its
// own descriptor.
void TcpServerBalancer::setBalancer(bool enable)
{
    m_balancer = enable;
}
410 
incomingConnection(qintptr handle)411 void TcpServerBalancer::incomingConnection(qintptr handle)
412 {
413     TcpServer *serverIdle = m_servers.at(m_currentServer++ % m_servers.size());
414 
415     Q_EMIT serverIdle->createConnection(handle);
416 }
417 
createServer(CWsgiEngine * engine)418 TcpServer *TcpServerBalancer::createServer(CWsgiEngine *engine)
419 {
420     TcpServer *server;
421     if (m_sslConfiguration) {
422 #ifndef QT_NO_SSL
423         auto sslServer = new TcpSslServer(m_serverName, m_protocol, m_wsgi, engine);
424         sslServer->setSslConfiguration(*m_sslConfiguration);
425         server = sslServer;
426 #endif //QT_NO_SSL
427     } else {
428         server = new TcpServer(m_serverName, m_protocol, m_wsgi, engine);
429     }
430     connect(engine, &CWsgiEngine::shutdown, server, &TcpServer::shutdown);
431 
432     if (m_balancer) {
433         connect(engine, &CWsgiEngine::started, this, [=] () {
434             m_servers.push_back(server);
435             resumeAccepting();
436         }, Qt::QueuedConnection);
437         connect(server, &TcpServer::createConnection, server, &TcpServer::incomingConnection, Qt::QueuedConnection);
438     } else {
439 
440 #ifdef Q_OS_LINUX
441         if (m_wsgi->reusePort()) {
442             connect(engine, &CWsgiEngine::started, this, [=] () {
443                 int socket = listenReuse(m_address, m_wsgi->listenQueue(), m_port, m_wsgi->reusePort(), true);
444                 if (!server->setSocketDescriptor(socket)) {
445                     qFatal("Failed to set server socket descriptor, reuse-port");
446                 }
447             }, Qt::DirectConnection);
448             return server;
449         }
450 #endif
451 
452         if (server->setSocketDescriptor(socketDescriptor())) {
453             server->pauseAccepting();
454             connect(engine, &CWsgiEngine::started, server, &TcpServer::resumeAccepting, Qt::DirectConnection);
455         } else {
456             qFatal("Failed to set server socket descriptor");
457         }
458     }
459 
460     return server;
461 }
462 
463 #include "moc_tcpserverbalancer.cpp"
464