1 /*
2  * This file is part of PowerDNS or dnsdist.
3  * Copyright -- PowerDNS.COM B.V. and its contributors
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of version 2 of the GNU General Public License as
7  * published by the Free Software Foundation.
8  *
9  * In addition, for the avoidance of any doubt, permission is granted to
10  * link this program with OpenSSL and to (re)distribute the binaries
11  * produced as the result of such linking.
12  *
13  * This program is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21  */
22 
23 #include <thread>
24 #include <netinet/tcp.h>
25 #include <queue>
26 
27 #include "dnsdist.hh"
28 #include "dnsdist-ecs.hh"
29 #include "dnsdist-proxy-protocol.hh"
30 #include "dnsdist-rings.hh"
31 #include "dnsdist-tcp-downstream.hh"
32 #include "dnsdist-tcp-upstream.hh"
33 #include "dnsdist-xpf.hh"
34 #include "dnsparser.hh"
35 #include "dolog.hh"
36 #include "gettime.hh"
37 #include "lock.hh"
38 #include "sstuff.hh"
39 #include "tcpiohandler.hh"
40 #include "tcpiohandler-mplexer.hh"
41 #include "threadname.hh"
42 
43 /* TCP: the grand design.
44    We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
45    An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
46    we will not go there.
47 
48    In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been setup.
49    This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
50    to guarantee performance.
51 
52    So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
53    So whenever an answer comes in, we know where it needs to go.
54 
55    Let's start naively.
56 */
57 
58 static std::mutex s_tcpClientsCountMutex;
59 static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> s_tcpClientsCount;
60 
61 size_t g_maxTCPQueriesPerConn{0};
62 size_t g_maxTCPConnectionDuration{0};
63 size_t g_maxTCPConnectionsPerClient{0};
64 #ifdef __linux__
65 // On Linux this gives us 128k pending queries (default is 8192 queries),
66 // which should be enough to deal with huge spikes
67 size_t g_tcpInternalPipeBufferSize{1024*1024};
68 uint64_t g_maxTCPQueuedConnections{10000};
69 #else
70 size_t g_tcpInternalPipeBufferSize{0};
71 uint64_t g_maxTCPQueuedConnections{1000};
72 #endif
73 uint16_t g_downstreamTCPCleanupInterval{60};
74 int g_tcpRecvTimeout{2};
75 int g_tcpSendTimeout{2};
76 bool g_useTCPSinglePipe{false};
77 std::atomic<uint64_t> g_tcpStatesDumpRequested{0};
78 
79 class DownstreamConnectionsManager
80 {
81 public:
82 
getConnectionToDownstream(std::unique_ptr<FDMultiplexer> & mplexer,std::shared_ptr<DownstreamState> & ds,const struct timeval & now)83   static std::shared_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
84   {
85     std::shared_ptr<TCPConnectionToBackend> result;
86     struct timeval freshCutOff = now;
87     freshCutOff.tv_sec -= 1;
88 
89     const auto& it = t_downstreamConnections.find(ds);
90     if (it != t_downstreamConnections.end()) {
91       auto& list = it->second;
92       while (!list.empty()) {
93         result = std::move(list.back());
94         list.pop_back();
95 
96         result->setReused();
97         /* for connections that have not been used very recently,
98            check whether they have been closed in the meantime */
99         if (freshCutOff < result->getLastDataReceivedTime()) {
100           /* used recently enough, skip the check */
101           ++ds->tcpReusedConnections;
102           return result;
103         }
104 
105         if (isTCPSocketUsable(result->getHandle())) {
106           ++ds->tcpReusedConnections;
107           return result;
108         }
109 
110         /* otherwise let's try the next one, if any */
111       }
112     }
113 
114     return std::make_shared<TCPConnectionToBackend>(ds, now);
115   }
116 
releaseDownstreamConnection(std::shared_ptr<TCPConnectionToBackend> && conn)117   static void releaseDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>&& conn)
118   {
119     if (conn == nullptr) {
120       return;
121     }
122 
123     if (!conn->canBeReused()) {
124       conn.reset();
125       return;
126     }
127 
128     const auto& ds = conn->getDS();
129     auto& list = t_downstreamConnections[ds];
130     while (list.size() >= s_maxCachedConnectionsPerDownstream) {
131       /* too many connections queued already */
132       list.pop_front();
133     }
134 
135     list.push_back(std::move(conn));
136   }
137 
cleanupClosedTCPConnections(struct timeval now)138   static void cleanupClosedTCPConnections(struct timeval now)
139   {
140     struct timeval freshCutOff = now;
141     freshCutOff.tv_sec -= 1;
142 
143     for (auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
144       for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
145         if (!(*connIt)) {
146           ++connIt;
147           continue;
148         }
149 
150         /* don't bother checking freshly used connections */
151         if (freshCutOff < (*connIt)->getLastDataReceivedTime()) {
152           ++connIt;
153           continue;
154         }
155 
156         if (isTCPSocketUsable((*connIt)->getHandle())) {
157           ++connIt;
158         }
159         else {
160           connIt = dsIt->second.erase(connIt);
161         }
162       }
163 
164       if (!dsIt->second.empty()) {
165         ++dsIt;
166       }
167       else {
168         dsIt = t_downstreamConnections.erase(dsIt);
169       }
170     }
171   }
172 
clear()173   static size_t clear()
174   {
175     size_t count = 0;
176     for (const auto& downstream : t_downstreamConnections) {
177       count += downstream.second.size();
178     }
179 
180     t_downstreamConnections.clear();
181 
182     return count;
183   }
184 
setMaxCachedConnectionsPerDownstream(size_t max)185   static void setMaxCachedConnectionsPerDownstream(size_t max)
186   {
187     s_maxCachedConnectionsPerDownstream = max;
188   }
189 
190 private:
191   static thread_local map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> t_downstreamConnections;
192   static size_t s_maxCachedConnectionsPerDownstream;
193 };
194 
setMaxCachedTCPConnectionsPerDownstream(size_t max)195 void setMaxCachedTCPConnectionsPerDownstream(size_t max)
196 {
197   DownstreamConnectionsManager::setMaxCachedConnectionsPerDownstream(max);
198 }
199 
200 thread_local map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> DownstreamConnectionsManager::t_downstreamConnections;
201 size_t DownstreamConnectionsManager::s_maxCachedConnectionsPerDownstream{10};
202 
decrementTCPClientCount(const ComboAddress & client)203 static void decrementTCPClientCount(const ComboAddress& client)
204 {
205   if (g_maxTCPConnectionsPerClient) {
206     std::lock_guard<std::mutex> lock(s_tcpClientsCountMutex);
207     s_tcpClientsCount.at(client)--;
208     if (s_tcpClientsCount[client] == 0) {
209       s_tcpClientsCount.erase(client);
210     }
211   }
212 }
213 
~IncomingTCPConnectionState()214 IncomingTCPConnectionState::~IncomingTCPConnectionState()
215 {
216   decrementTCPClientCount(d_ci.remote);
217 
218   if (d_ci.cs != nullptr) {
219     struct timeval now;
220     gettimeofday(&now, nullptr);
221 
222     auto diff = now - d_connectionStartTime;
223     d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
224   }
225 
226   // would have been done when the object is destroyed anyway,
227   // but that way we make sure it's done before the ConnectionInfo is destroyed,
228   // closing the descriptor, instead of relying on the declaration order of the objects in the class
229   d_handler.close();
230 }
231 
clearAllDownstreamConnections()232 size_t IncomingTCPConnectionState::clearAllDownstreamConnections()
233 {
234   return DownstreamConnectionsManager::clear();
235 }
236 
getDownstreamConnection(std::shared_ptr<DownstreamState> & ds,const std::unique_ptr<std::vector<ProxyProtocolValue>> & tlvs,const struct timeval & now)237 std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now)
238 {
239   std::shared_ptr<TCPConnectionToBackend> downstream{nullptr};
240 
241   downstream = getActiveDownstreamConnection(ds, tlvs);
242 
243   if (!downstream) {
244     /* we don't have a connection to this backend active yet, let's get one (it might not be a fresh one, though) */
245     downstream = DownstreamConnectionsManager::getConnectionToDownstream(d_threadData.mplexer, ds, now);
246     registerActiveDownstreamConnection(downstream);
247   }
248 
249   return downstream;
250 }
251 
252 static void tcpClientThread(int pipefd);
253 
TCPClientCollection(size_t maxThreads,bool useSinglePipe)254 TCPClientCollection::TCPClientCollection(size_t maxThreads, bool useSinglePipe): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads), d_singlePipe{-1,-1}, d_useSinglePipe(useSinglePipe)
255 {
256   if (d_useSinglePipe) {
257     if (pipe(d_singlePipe) < 0) {
258       int err = errno;
259       throw std::runtime_error("Error creating the TCP single communication pipe: " + stringerror(err));
260     }
261 
262     if (!setNonBlocking(d_singlePipe[0])) {
263       int err = errno;
264       close(d_singlePipe[0]);
265       close(d_singlePipe[1]);
266       throw std::runtime_error("Error setting the TCP single communication pipe non-blocking: " + stringerror(err));
267     }
268 
269     if (!setNonBlocking(d_singlePipe[1])) {
270       int err = errno;
271       close(d_singlePipe[0]);
272       close(d_singlePipe[1]);
273       throw std::runtime_error("Error setting the TCP single communication pipe non-blocking: " + stringerror(err));
274     }
275 
276     if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(d_singlePipe[0]) < g_tcpInternalPipeBufferSize) {
277       setPipeBufferSize(d_singlePipe[0], g_tcpInternalPipeBufferSize);
278     }
279   }
280 }
281 
addTCPClientThread()282 void TCPClientCollection::addTCPClientThread()
283 {
284   int pipefds[2] = { -1, -1};
285 
286   vinfolog("Adding TCP Client thread");
287 
288   if (d_useSinglePipe) {
289     pipefds[0] = d_singlePipe[0];
290     pipefds[1] = d_singlePipe[1];
291   }
292   else {
293     if (pipe(pipefds) < 0) {
294       errlog("Error creating the TCP thread communication pipe: %s", stringerror());
295       return;
296     }
297 
298     if (!setNonBlocking(pipefds[0])) {
299       int err = errno;
300       close(pipefds[0]);
301       close(pipefds[1]);
302       errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));
303       return;
304     }
305 
306     if (!setNonBlocking(pipefds[1])) {
307       int err = errno;
308       close(pipefds[0]);
309       close(pipefds[1]);
310       errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));
311       return;
312     }
313 
314     if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(pipefds[0]) < g_tcpInternalPipeBufferSize) {
315       setPipeBufferSize(pipefds[0], g_tcpInternalPipeBufferSize);
316     }
317   }
318 
319   {
320     std::lock_guard<std::mutex> lock(d_mutex);
321 
322     if (d_numthreads >= d_tcpclientthreads.size()) {
323       vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
324       if (!d_useSinglePipe) {
325         close(pipefds[0]);
326         close(pipefds[1]);
327       }
328       return;
329     }
330 
331     try {
332       std::thread t1(tcpClientThread, pipefds[0]);
333       t1.detach();
334     }
335     catch (const std::runtime_error& e) {
336       /* the thread creation failed, don't leak */
337       errlog("Error creating a TCP thread: %s", e.what());
338       if (!d_useSinglePipe) {
339         close(pipefds[0]);
340         close(pipefds[1]);
341       }
342       return;
343     }
344 
345     d_tcpclientthreads.at(d_numthreads) = pipefds[1];
346     ++d_numthreads;
347   }
348 }
349 
350 std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
351 
sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState> & state,const struct timeval & now)352 static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
353 {
354   IOState result = IOState::Done;
355 
356   while (state->active() && !state->d_queuedResponses.empty()) {
357     DEBUGLOG("queue size is "<<state->d_queuedResponses.size()<<", sending the next one");
358     TCPResponse resp = std::move(state->d_queuedResponses.front());
359     state->d_queuedResponses.pop_front();
360     state->d_state = IncomingTCPConnectionState::State::idle;
361     result = state->sendResponse(state, now, std::move(resp));
362     if (result != IOState::Done) {
363       return result;
364     }
365   }
366 
367   state->d_state = IncomingTCPConnectionState::State::idle;
368   return IOState::Done;
369 }
370 
handleResponseSent(std::shared_ptr<IncomingTCPConnectionState> & state,const TCPResponse & currentResponse)371 static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, const TCPResponse& currentResponse)
372 {
373   if (state->d_isXFR) {
374     return;
375   }
376 
377   --state->d_currentQueriesCount;
378 
379   if (currentResponse.d_selfGenerated == false && currentResponse.d_connection && currentResponse.d_connection->getDS()) {
380     const auto& ds = currentResponse.d_connection->getDS();
381     struct timespec answertime;
382     gettime(&answertime);
383     const auto& ids = currentResponse.d_idstate;
384     double udiff = ids.sentTime.udiff();
385     g_rings.insertResponse(answertime, state->d_ci.remote, ids.qname, ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ds->remote);
386     vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f usec", ds->remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), currentResponse.d_buffer.size(), udiff);
387   }
388 
389   switch (currentResponse.d_cleartextDH.rcode) {
390   case RCode::NXDomain:
391     ++g_stats.frontendNXDomain;
392     break;
393   case RCode::ServFail:
394     ++g_stats.servfailResponses;
395     ++g_stats.frontendServFail;
396     break;
397   case RCode::NoError:
398     ++g_stats.frontendNoError;
399     break;
400   }
401 }
402 
canAcceptNewQueries(const struct timeval & now)403 bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now)
404 {
405   if (d_hadErrors) {
406     DEBUGLOG("not accepting new queries because we encountered some error during the processing already");
407     return false;
408   }
409 
410   if (d_isXFR) {
411     DEBUGLOG("not accepting new queries because used for XFR");
412     return false;
413   }
414 
415   if (d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
416     DEBUGLOG("not accepting new queries because we already have "<<d_currentQueriesCount<<" out of "<<d_ci.cs->d_maxInFlightQueriesPerConn);
417     return false;
418   }
419 
420   if (g_maxTCPQueriesPerConn && d_queriesCount > g_maxTCPQueriesPerConn) {
421     vinfolog("not accepting new queries from %s because it reached the maximum number of queries per conn (%d / %d)", d_ci.remote.toStringWithPort(), d_queriesCount, g_maxTCPQueriesPerConn);
422     return false;
423   }
424 
425   if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
426     vinfolog("not accepting new queries from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
427     return false;
428   }
429 
430   return true;
431 }
432 
resetForNewQuery()433 void IncomingTCPConnectionState::resetForNewQuery()
434 {
435   d_buffer.resize(sizeof(uint16_t));
436   d_currentPos = 0;
437   d_querySize = 0;
438   d_state = State::waitingForQuery;
439 }
440 
getActiveDownstreamConnection(const std::shared_ptr<DownstreamState> & ds,const std::unique_ptr<std::vector<ProxyProtocolValue>> & tlvs)441 std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs)
442 {
443   auto it = d_activeConnectionsToBackend.find(ds);
444   if (it == d_activeConnectionsToBackend.end()) {
445     DEBUGLOG("no active connection found for "<<ds->getName());
446     return nullptr;
447   }
448 
449   for (auto& conn : it->second) {
450     if (conn->canAcceptNewQueries() && conn->matchesTLVs(tlvs)) {
451       DEBUGLOG("Got one active connection accepting more for "<<ds->getName());
452       conn->setReused();
453       return conn;
454     }
455     DEBUGLOG("not accepting more for "<<ds->getName());
456   }
457 
458   return nullptr;
459 }
460 
registerActiveDownstreamConnection(std::shared_ptr<TCPConnectionToBackend> & conn)461 void IncomingTCPConnectionState::registerActiveDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn)
462 {
463   d_activeConnectionsToBackend[conn->getDS()].push_front(conn);
464 }
465 
466 /* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */
sendResponse(std::shared_ptr<IncomingTCPConnectionState> & state,const struct timeval & now,TCPResponse && response)467 IOState IncomingTCPConnectionState::sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
468 {
469   state->d_state = IncomingTCPConnectionState::State::sendingResponse;
470 
471   uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size());
472   const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) };
473   /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
474      that could occur if we had to deal with the size during the processing,
475      especially alignment issues */
476   response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2);
477   state->d_currentPos = 0;
478   state->d_currentResponse = std::move(response);
479 
480   try {
481     auto iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
482     if (iostate == IOState::Done) {
483       DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
484       handleResponseSent(state, state->d_currentResponse);
485       return iostate;
486     } else {
487       state->d_lastIOBlocked = true;
488       DEBUGLOG("partial write");
489       return iostate;
490     }
491   }
492   catch (const std::exception& e) {
493     vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
494     DEBUGLOG("Closing TCP client connection: "<<e.what());
495     ++state->d_ci.cs->tcpDiedSendingResponse;
496 
497     state->terminateClientConnection();
498 
499     return IOState::Done;
500   }
501 }
502 
terminateClientConnection()503 void IncomingTCPConnectionState::terminateClientConnection()
504 {
505   DEBUGLOG("terminating client connection");
506   d_queuedResponses.clear();
507   /* we have already released idle connections that could be reused,
508      we don't care about the ones still waiting for responses */
509   d_activeConnectionsToBackend.clear();
510   /* meaning we will no longer be 'active' when the backend
511      response or timeout comes in */
512   d_ioState.reset();
513   d_handler.close();
514 }
515 
queueResponse(std::shared_ptr<IncomingTCPConnectionState> & state,const struct timeval & now,TCPResponse && response)516 void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
517 {
518   // queue response
519   state->d_queuedResponses.push_back(std::move(response));
520   DEBUGLOG("queueing response, state is "<<(int)state->d_state<<", queue size is now "<<state->d_queuedResponses.size());
521 
522   // when the response comes from a backend, there is a real possibility that we are currently
523   // idle, and thus not trying to send the response right away would make our ref count go to 0.
524   // Even if we are waiting for a query, we will not wake up before the new query arrives or a
525   // timeout occurs
526   if (state->d_state == IncomingTCPConnectionState::State::idle ||
527       state->d_state == IncomingTCPConnectionState::State::waitingForQuery) {
528     auto iostate = sendQueuedResponses(state, now);
529 
530     if (iostate == IOState::Done && state->active()) {
531       if (state->canAcceptNewQueries(now)) {
532         state->resetForNewQuery();
533         state->d_state = IncomingTCPConnectionState::State::waitingForQuery;
534         iostate = IOState::NeedRead;
535       }
536       else {
537         state->d_state = IncomingTCPConnectionState::State::idle;
538       }
539     }
540 
541     // for the same reason we need to update the state right away, nobody will do that for us
542     if (state->active()) {
543       state->d_ioState->update(iostate, handleIOCallback, state, iostate == IOState::NeedWrite ? state->getClientWriteTTD(now) : state->getClientReadTTD(now));
544     }
545   }
546 }
547 
548 /* called from the backend code when a new response has been received */
handleResponse(std::shared_ptr<IncomingTCPConnectionState> state,const struct timeval & now,TCPResponse && response)549 void IncomingTCPConnectionState::handleResponse(std::shared_ptr<IncomingTCPConnectionState> state, const struct timeval& now, TCPResponse&& response)
550 {
551   if (response.d_connection && response.d_connection->isIdle()) {
552     // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool yet, no one else will be able to use it anyway
553     if (response.d_connection->canBeReused()) {
554       auto& list = state->d_activeConnectionsToBackend.at(response.d_connection->getDS());
555 
556       for (auto it = list.begin(); it != list.end(); ++it) {
557         if (*it == response.d_connection) {
558           try {
559             response.d_connection->release();
560             DownstreamConnectionsManager::releaseDownstreamConnection(std::move(*it));
561           }
562           catch (const std::exception& e) {
563             vinfolog("Error releasing connection: %s", e.what());
564           }
565           list.erase(it);
566           break;
567         }
568       }
569     }
570   }
571 
572   if (response.d_buffer.size() < sizeof(dnsheader)) {
573     state->terminateClientConnection();
574     return;
575   }
576 
577   try {
578     auto& ids = response.d_idstate;
579     unsigned int qnameWireLength;
580     if (!responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getRemote(), qnameWireLength)) {
581       state->terminateClientConnection();
582       return;
583     }
584 
585     DNSResponse dr = makeDNSResponseFromIDState(ids, response.d_buffer, true);
586 
587     memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH));
588 
589     if (!processResponse(response.d_buffer, state->d_threadData.localRespRuleActions, dr, false)) {
590       state->terminateClientConnection();
591       return;
592     }
593   }
594   catch (const std::exception& e) {
595     vinfolog("Unexpected exception while handling response from backend: %s", e.what());
596     state->terminateClientConnection();
597     return;
598   }
599 
600   ++g_stats.responses;
601   ++state->d_ci.cs->responses;
602   if (response.d_connection->getDS()) {
603     ++response.d_connection->getDS()->responses;
604   }
605 
606   queueResponse(state, now, std::move(response));
607 }
608 
handleQuery(std::shared_ptr<IncomingTCPConnectionState> & state,const struct timeval & now)609 static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
610 {
611   if (state->d_querySize < sizeof(dnsheader)) {
612     ++g_stats.nonCompliantQueries;
613     state->terminateClientConnection();
614     return;
615   }
616 
617   ++state->d_queriesCount;
618   ++state->d_ci.cs->queries;
619   ++g_stats.queries;
620 
621   if (state->d_handler.isTLS()) {
622     auto tlsVersion = state->d_handler.getTLSVersion();
623     switch (tlsVersion) {
624     case LibsslTLSVersion::TLS10:
625       ++state->d_ci.cs->tls10queries;
626       break;
627     case LibsslTLSVersion::TLS11:
628       ++state->d_ci.cs->tls11queries;
629       break;
630     case LibsslTLSVersion::TLS12:
631       ++state->d_ci.cs->tls12queries;
632       break;
633     case LibsslTLSVersion::TLS13:
634       ++state->d_ci.cs->tls13queries;
635       break;
636     default:
637       ++state->d_ci.cs->tlsUnknownqueries;
638     }
639   }
640 
641   /* we need an accurate ("real") value for the response and
642      to store into the IDS, but not for insertion into the
643      rings for example */
644   struct timespec queryRealTime;
645   gettime(&queryRealTime, true);
646 
647   std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
648   auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, dnsCryptQuery, queryRealTime.tv_sec, true);
649   if (dnsCryptResponse) {
650     TCPResponse response;
651     state->d_state = IncomingTCPConnectionState::State::idle;
652     ++state->d_currentQueriesCount;
653     state->queueResponse(state, now, std::move(response));
654     return;
655   }
656 
657   {
658     /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
659     auto* dh = reinterpret_cast<dnsheader*>(state->d_buffer.data());
660     if (!checkQueryHeaders(dh)) {
661       state->terminateClientConnection();
662       return;
663     }
664 
665     if (dh->qdcount == 0) {
666       TCPResponse response;
667       dh->rcode = RCode::NotImp;
668       dh->qr = true;
669       response.d_selfGenerated = true;
670       response.d_buffer = std::move(state->d_buffer);
671       state->d_state = IncomingTCPConnectionState::State::idle;
672       ++state->d_currentQueriesCount;
673       state->queueResponse(state, now, std::move(response));
674       return;
675     }
676   }
677 
678   uint16_t qtype, qclass;
679   unsigned int qnameWireLength = 0;
680   DNSName qname(reinterpret_cast<const char*>(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
681   DNSQuestion dq(&qname, qtype, qclass, &state->d_proxiedDestination, &state->d_proxiedRemote, state->d_buffer, true, &queryRealTime);
682   dq.dnsCryptQuery = std::move(dnsCryptQuery);
683   dq.sni = state->d_handler.getServerNameIndication();
684   if (state->d_proxyProtocolValues) {
685     /* we need to copy them, because the next queries received on that connection will
686        need to get the _unaltered_ values */
687     dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*state->d_proxyProtocolValues);
688   }
689 
690   state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
691   if (state->d_isXFR) {
692     dq.skipCache = true;
693   }
694 
695   std::shared_ptr<DownstreamState> ds;
696   auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, ds);
697 
698   if (result == ProcessQueryResult::Drop) {
699     state->terminateClientConnection();
700     return;
701   }
702 
703   // the buffer might have been invalidated by now
704   const dnsheader* dh = dq.getHeader();
705   if (result == ProcessQueryResult::SendAnswer) {
706     TCPResponse response;
707     response.d_selfGenerated = true;
708     response.d_buffer = std::move(state->d_buffer);
709     state->d_state = IncomingTCPConnectionState::State::idle;
710     ++state->d_currentQueriesCount;
711     state->queueResponse(state, now, std::move(response));
712     return;
713   }
714 
715   if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
716     state->terminateClientConnection();
717     return;
718   }
719 
720   IDState ids;
721   setIDStateFromDNSQuestion(ids, dq, std::move(qname));
722   ids.origID = ntohs(dh->id);
723 
724   uint16_t queryLen = state->d_buffer.size();
725   const uint8_t sizeBytes[] = { static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256) };
726   /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
727      that could occur if we had to deal with the size during the processing,
728      especially alignment issues */
729   state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
730 
731   auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
732   downstreamConnection->assignToClientConnection(state, state->d_isXFR);
733 
734   bool proxyProtocolPayloadAdded = false;
735   std::string proxyProtocolPayload;
736 
737   if (ds->useProxyProtocol) {
738     /* if we ever sent a TLV over a connection, we can never go back */
739     if (!state->d_proxyProtocolPayloadHasTLV) {
740       state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
741     }
742 
743     proxyProtocolPayload = getProxyProtocolPayload(dq);
744     if (state->d_proxyProtocolPayloadHasTLV && downstreamConnection->isFresh()) {
745       /* we will not be able to reuse an existing connection anyway so let's add the payload right now */
746       addProxyProtocol(state->d_buffer, proxyProtocolPayload);
747       proxyProtocolPayloadAdded = true;
748     }
749   }
750 
751   if (dq.proxyProtocolValues) {
752     downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues));
753   }
754 
755   TCPQuery query(std::move(state->d_buffer), std::move(ids));
756   if (proxyProtocolPayloadAdded) {
757     query.d_proxyProtocolPayloadAdded = true;
758   }
759   else {
760     query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
761   }
762 
763   ++state->d_currentQueriesCount;
764   vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).getName(), state->d_proxiedRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), query.d_buffer.size(), ds->getName());
765   downstreamConnection->queueQuery(std::move(query), downstreamConnection);
766 }
767 
handleIOCallback(int fd,FDMultiplexer::funcparam_t & param)768 void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
769 {
770   auto conn = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
771   if (fd != conn->d_handler.getDescriptor()) {
772     throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor()));
773   }
774 
775   struct timeval now;
776   gettimeofday(&now, nullptr);
777   handleIO(conn, now);
778 }
779 
handleIO(std::shared_ptr<IncomingTCPConnectionState> & state,const struct timeval & now)780 void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
781 {
782   // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
783   // even though the underlying socket is not ready, so we need to actually ask for the data first
784   IOState iostate = IOState::Done;
785   do {
786     iostate = IOState::Done;
787     IOStateGuard ioGuard(state->d_ioState);
788 
789     if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
790       vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
791       // will be handled by the ioGuard
792       //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
793       return;
794     }
795 
796     state->d_lastIOBlocked = false;
797 
798     try {
799       if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
800         DEBUGLOG("doing handshake");
801         iostate = state->d_handler.tryHandshake();
802         if (iostate == IOState::Done) {
803           DEBUGLOG("handshake done");
804           if (state->d_handler.isTLS()) {
805             if (!state->d_handler.hasTLSSessionBeenResumed()) {
806               ++state->d_ci.cs->tlsNewSessions;
807             }
808             else {
809               ++state->d_ci.cs->tlsResumptions;
810             }
811             if (state->d_handler.getResumedFromInactiveTicketKey()) {
812               ++state->d_ci.cs->tlsInactiveTicketKey;
813             }
814             if (state->d_handler.getUnknownTicketKey()) {
815               ++state->d_ci.cs->tlsUnknownTicketKey;
816             }
817           }
818 
819           state->d_handshakeDoneTime = now;
820           if (expectProxyProtocolFrom(state->d_ci.remote)) {
821             state->d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
822             state->d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
823             state->d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
824           }
825           else {
826             state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
827           }
828         }
829         else {
830           state->d_lastIOBlocked = true;
831         }
832       }
833 
834       if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
835         do {
836           DEBUGLOG("reading proxy protocol header");
837           iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_proxyProtocolNeed);
838           if (iostate == IOState::Done) {
839             state->d_buffer.resize(state->d_currentPos);
840             ssize_t remaining = isProxyHeaderComplete(state->d_buffer);
841             if (remaining == 0) {
842               vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", state->d_ci.remote.toStringWithPort());
843               ++g_stats.proxyProtocolInvalid;
844               break;
845             }
846             else if (remaining < 0) {
847               state->d_proxyProtocolNeed += -remaining;
848               state->d_buffer.resize(state->d_currentPos + state->d_proxyProtocolNeed);
849               /* we need to keep reading, since we might have buffered data */
850               iostate = IOState::NeedRead;
851             }
852             else {
853               /* proxy header received */
854               std::vector<ProxyProtocolValue> proxyProtocolValues;
855               if (!handleProxyProtocol(state->d_ci.remote, true, *state->d_threadData.holders.acl, state->d_buffer, state->d_proxiedRemote, state->d_proxiedDestination, proxyProtocolValues)) {
856                 vinfolog("Error handling the Proxy Protocol received from TCP client %s", state->d_ci.remote.toStringWithPort());
857                 break;
858               }
859 
860               if (!proxyProtocolValues.empty()) {
861                 state->d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
862               }
863 
864               state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
865               state->d_buffer.resize(sizeof(uint16_t));
866               state->d_currentPos = 0;
867               state->d_proxyProtocolNeed = 0;
868               break;
869             }
870           }
871           else {
872             state->d_lastIOBlocked = true;
873           }
874         }
875         while (state->active() && !state->d_lastIOBlocked);
876       }
877 
878       if (!state->d_lastIOBlocked && (state->d_state == IncomingTCPConnectionState::State::waitingForQuery ||
879                                       state->d_state == IncomingTCPConnectionState::State::readingQuerySize)) {
880         DEBUGLOG("reading query size");
881         iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
882         if (state->d_currentPos > 0) {
883           /* if we got at least one byte, we can't go around sending responses */
884           state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
885         }
886 
887         if (iostate == IOState::Done) {
888           DEBUGLOG("query size received");
889           state->d_state = IncomingTCPConnectionState::State::readingQuery;
890           state->d_querySizeReadTime = now;
891           if (state->d_queriesCount == 0) {
892             state->d_firstQuerySizeReadTime = now;
893           }
894           state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
895           if (state->d_querySize < sizeof(dnsheader)) {
896             /* go away */
897             state->terminateClientConnection();
898             return;
899           }
900 
901           /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
902              or to add ECS without allocating a new buffer */
903           state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
904           state->d_currentPos = 0;
905         }
906         else {
907           state->d_lastIOBlocked = true;
908         }
909       }
910 
911       if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingQuery) {
912         DEBUGLOG("reading query");
913         iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
914         if (iostate == IOState::Done) {
915           DEBUGLOG("query received");
916           state->d_buffer.resize(state->d_querySize);
917 
918           state->d_state = IncomingTCPConnectionState::State::idle;
919           handleQuery(state, now);
920           /* the state might have been updated in the meantime, we don't want to override it
921              in that case */
922           if (state->active() && state->d_state != IncomingTCPConnectionState::State::idle) {
923             iostate = state->d_ioState->getState();
924           }
925         }
926         else {
927           state->d_lastIOBlocked = true;
928         }
929       }
930 
931       if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
932         DEBUGLOG("sending response");
933         iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
934         if (iostate == IOState::Done) {
935           DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
936           handleResponseSent(state, state->d_currentResponse);
937           state->d_state = IncomingTCPConnectionState::State::idle;
938         }
939         else {
940           state->d_lastIOBlocked = true;
941         }
942       }
943 
944       if (state->active() &&
945           !state->d_lastIOBlocked &&
946           iostate == IOState::Done &&
947           (state->d_state == IncomingTCPConnectionState::State::idle ||
948            state->d_state == IncomingTCPConnectionState::State::waitingForQuery))
949       {
950         // try sending queued responses
951         DEBUGLOG("send responses, if any");
952         iostate = sendQueuedResponses(state, now);
953 
954         if (!state->d_lastIOBlocked && state->active() && iostate == IOState::Done) {
955           // if the query has been passed to a backend, or dropped, and the responses have been sent,
956           // we can start reading again
957           if (state->canAcceptNewQueries(now)) {
958             state->resetForNewQuery();
959             iostate = IOState::NeedRead;
960           }
961           else {
962             state->d_state = IncomingTCPConnectionState::State::idle;
963             iostate = IOState::Done;
964           }
965         }
966       }
967 
968       if (state->d_state != IncomingTCPConnectionState::State::idle &&
969           state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
970           state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader &&
971           state->d_state != IncomingTCPConnectionState::State::waitingForQuery &&
972           state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
973           state->d_state != IncomingTCPConnectionState::State::readingQuery &&
974           state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
975         vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
976       }
977     }
978     catch (const std::exception& e) {
979       /* most likely an EOF because the other end closed the connection,
980          but it might also be a real IO error or something else.
981          Let's just drop the connection
982       */
983       if (state->d_state == IncomingTCPConnectionState::State::idle ||
984           state->d_state == IncomingTCPConnectionState::State::waitingForQuery) {
985         /* no need to increase any counters in that case, the client is simply done with us */
986       }
987       else if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
988                state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
989                state->d_state == IncomingTCPConnectionState::State::waitingForQuery ||
990                state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
991                state->d_state == IncomingTCPConnectionState::State::readingQuery) {
992         ++state->d_ci.cs->tcpDiedReadingQuery;
993       }
994       else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
995         /* unlikely to happen here, the exception should be handled in sendResponse() */
996         ++state->d_ci.cs->tcpDiedSendingResponse;
997       }
998 
999       if (state->d_ioState->getState() == IOState::NeedWrite || state->d_queriesCount == 0) {
1000         DEBUGLOG("Got an exception while handling TCP query: "<<e.what());
1001         vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_ioState->getState() == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
1002       }
1003       else {
1004         vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
1005         DEBUGLOG("Closing TCP client connection: "<<e.what());
1006       }
1007       /* remove this FD from the IO multiplexer */
1008       state->terminateClientConnection();
1009     }
1010 
1011     if (!state->active()) {
1012       DEBUGLOG("state is no longer active");
1013       return;
1014     }
1015 
1016     if (iostate == IOState::Done) {
1017       state->d_ioState->update(iostate, handleIOCallback, state);
1018     }
1019     else {
1020       state->d_ioState->update(iostate, handleIOCallback, state, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
1021     }
1022     ioGuard.release();
1023   }
1024   while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !state->d_lastIOBlocked);
1025 }
1026 
notifyIOError(std::shared_ptr<IncomingTCPConnectionState> & state,IDState && query,const struct timeval & now)1027 void IncomingTCPConnectionState::notifyIOError(std::shared_ptr<IncomingTCPConnectionState>& state, IDState&& query, const struct timeval& now)
1028 {
1029   --state->d_currentQueriesCount;
1030   state->d_hadErrors = true;
1031 
1032   if (state->d_state == State::sendingResponse) {
1033     /* if we have responses to send, let's do that first */
1034   }
1035   else if (!state->d_queuedResponses.empty()) {
1036     /* stop reading and send what we have */
1037     try {
1038       auto iostate = sendQueuedResponses(state, now);
1039 
1040       if (state->active() && iostate != IOState::Done) {
1041         // we need to update the state right away, nobody will do that for us
1042         state->d_ioState->update(iostate, handleIOCallback, state, iostate == IOState::NeedWrite ? state->getClientWriteTTD(now) : state->getClientReadTTD(now));
1043       }
1044     }
1045     catch (const std::exception& e) {
1046       vinfolog("Exception in notifyIOError: %s", e.what());
1047     }
1048   }
1049   else {
1050     // the backend code already tried to reconnect if it was possible
1051     state->terminateClientConnection();
1052   }
1053 }
1054 
handleXFRResponse(std::shared_ptr<IncomingTCPConnectionState> & state,const struct timeval & now,TCPResponse && response)1055 void IncomingTCPConnectionState::handleXFRResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
1056 {
1057   queueResponse(state, now, std::move(response));
1058 }
1059 
handleTimeout(std::shared_ptr<IncomingTCPConnectionState> & state,bool write)1060 void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write)
1061 {
1062   vinfolog("Timeout while %s TCP client %s", (write ? "writing to" : "reading from"), state->d_ci.remote.toStringWithPort());
1063   DEBUGLOG("client timeout");
1064   DEBUGLOG("Processed "<<state->d_queriesCount<<" queries, current count is "<<state->d_currentQueriesCount<<", "<<state->d_activeConnectionsToBackend.size()<<" active connections, "<<state->d_queuedResponses.size()<<" response queued");
1065 
1066   if (write || state->d_currentQueriesCount == 0) {
1067     ++state->d_ci.cs->tcpClientTimeouts;
1068     state->d_ioState.reset();
1069   }
1070   else {
1071     DEBUGLOG("Going idle");
1072     /* we still have some queries in flight, let's just stop reading for now */
1073     state->d_state = IncomingTCPConnectionState::State::idle;
1074     state->d_ioState->update(IOState::Done, handleIOCallback, state);
1075 
1076 #ifdef DEBUGLOG_ENABLED
1077     for (const auto& active : state->d_activeConnectionsToBackend) {
1078       for (const auto& conn: active.second) {
1079         DEBUGLOG("Connection to "<<active.first->getName()<<" is "<<(conn->isIdle() ? "idle" : "not idle"));
1080       }
1081     }
1082 #endif
1083   }
1084 }
1085 
handleIncomingTCPQuery(int pipefd,FDMultiplexer::funcparam_t & param)1086 static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
1087 {
1088   auto threadData = boost::any_cast<TCPClientThreadData*>(param);
1089 
1090   ConnectionInfo* citmp{nullptr};
1091 
1092   ssize_t got = read(pipefd, &citmp, sizeof(citmp));
1093   if (got == 0) {
1094     throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
1095   }
1096   else if (got == -1) {
1097     if (errno == EAGAIN || errno == EINTR) {
1098       return;
1099     }
1100     throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
1101   }
1102   else if (got != sizeof(citmp)) {
1103     throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
1104   }
1105 
1106   try {
1107     g_tcpclientthreads->decrementQueuedCount();
1108 
1109     struct timeval now;
1110     gettimeofday(&now, nullptr);
1111     auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
1112     delete citmp;
1113     citmp = nullptr;
1114 
1115     IncomingTCPConnectionState::handleIO(state, now);
1116   }
1117   catch(...) {
1118     delete citmp;
1119     citmp = nullptr;
1120     throw;
1121   }
1122 }
1123 
tcpClientThread(int pipefd)1124 static void tcpClientThread(int pipefd)
1125 {
1126   /* we get launched with a pipe on which we receive file descriptors from clients that we own
1127      from that point on */
1128 
1129   setThreadName("dnsdist/tcpClie");
1130 
1131   TCPClientThreadData data;
1132 
1133   data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
1134   struct timeval now;
1135   gettimeofday(&now, nullptr);
1136   time_t lastTCPCleanup = now.tv_sec;
1137   time_t lastTimeoutScan = now.tv_sec;
1138 
1139   for (;;) {
1140     data.mplexer->run(&now);
1141 
1142     if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
1143       DownstreamConnectionsManager::cleanupClosedTCPConnections(now);
1144       lastTCPCleanup = now.tv_sec;
1145 
1146       if (g_tcpStatesDumpRequested > 0) {
1147         /* just to keep things clean in the output, debug only */
1148         static std::mutex s_lock;
1149         std::lock_guard<decltype(s_lock)> lck(s_lock);
1150         if (g_tcpStatesDumpRequested > 0) {
1151           /* no race here, we took the lock so it can only be increased in the meantime */
1152           --g_tcpStatesDumpRequested;
1153           errlog("Dumping the TCP states, as requested:");
1154           data.mplexer->runForAllWatchedFDs([](bool isRead, int fd, const FDMultiplexer::funcparam_t& param, struct timeval ttd)
1155           {
1156             struct timeval lnow;
1157             gettimeofday(&lnow, nullptr);
1158             if (ttd.tv_sec > 0) {
1159             errlog("- Descriptor %d is in %s state, TTD in %d", fd, (isRead ? "read" : "write"), (ttd.tv_sec-lnow.tv_sec));
1160             }
1161             else {
1162               errlog("- Descriptor %d is in %s state, no TTD set", fd, (isRead ? "read" : "write"));
1163             }
1164 
1165             if (param.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
1166               auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
1167               errlog(" - %s", state->toString());
1168             }
1169             else if (param.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
1170               auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param);
1171               errlog(" - %s", conn->toString());
1172             }
1173             else if (param.type() == typeid(TCPClientThreadData*)) {
1174               errlog(" - Worker thread pipe");
1175             }
1176           });
1177         }
1178       }
1179     }
1180 
1181     if (now.tv_sec > lastTimeoutScan) {
1182       lastTimeoutScan = now.tv_sec;
1183       auto expiredReadConns = data.mplexer->getTimeouts(now, false);
1184       for (const auto& cbData : expiredReadConns) {
1185         if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
1186           auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
1187           if (cbData.first == state->d_handler.getDescriptor()) {
1188             vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
1189             state->handleTimeout(state, false);
1190           }
1191         }
1192         else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
1193           auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
1194           vinfolog("Timeout (read) from remote backend %s", conn->getBackendName());
1195           conn->handleTimeout(now, false);
1196         }
1197       }
1198 
1199       auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
1200       for (const auto& cbData : expiredWriteConns) {
1201         if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
1202           auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
1203           if (cbData.first == state->d_handler.getDescriptor()) {
1204             vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
1205             state->handleTimeout(state, true);
1206           }
1207         }
1208         else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
1209           auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
1210           vinfolog("Timeout (write) from remote backend %s", conn->getBackendName());
1211           conn->handleTimeout(now, true);
1212         }
1213       }
1214     }
1215   }
1216 }
1217 
1218 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
1219    they will hand off to worker threads & spawn more of them if required
1220 */
tcpAcceptorThread(ClientState * cs)1221 void tcpAcceptorThread(ClientState* cs)
1222 {
1223   setThreadName("dnsdist/tcpAcce");
1224 
1225   bool tcpClientCountIncremented = false;
1226   ComboAddress remote;
1227   remote.sin4.sin_family = cs->local.sin4.sin_family;
1228 
1229   auto acl = g_ACL.getLocal();
1230   for(;;) {
1231     bool queuedCounterIncremented = false;
1232     std::unique_ptr<ConnectionInfo> ci;
1233     tcpClientCountIncremented = false;
1234     try {
1235       socklen_t remlen = remote.getSocklen();
1236       ci = std::make_unique<ConnectionInfo>(cs);
1237 #ifdef HAVE_ACCEPT4
1238       ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
1239 #else
1240       ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
1241 #endif
1242       // will be decremented when the ConnectionInfo object is destroyed, no matter the reason
1243       auto concurrentConnections = ++cs->tcpCurrentConnections;
1244       if (cs->d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > cs->d_tcpConcurrentConnectionsLimit) {
1245         continue;
1246       }
1247 
1248       if (concurrentConnections > cs->tcpMaxConcurrentConnections) {
1249         cs->tcpMaxConcurrentConnections = concurrentConnections;
1250       }
1251 
1252       if (ci->fd < 0) {
1253         throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());
1254       }
1255 
1256       if (!acl->match(remote)) {
1257 	++g_stats.aclDrops;
1258 	vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
1259 	continue;
1260       }
1261 
1262 #ifndef HAVE_ACCEPT4
1263       if (!setNonBlocking(ci->fd)) {
1264         continue;
1265       }
1266 #endif
1267       setTCPNoDelay(ci->fd);  // disable NAGLE
1268       if (g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
1269         vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
1270         continue;
1271       }
1272 
1273       if (g_maxTCPConnectionsPerClient) {
1274         std::lock_guard<std::mutex> lock(s_tcpClientsCountMutex);
1275 
1276         if (s_tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
1277           vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
1278           continue;
1279         }
1280         s_tcpClientsCount[remote]++;
1281         tcpClientCountIncremented = true;
1282       }
1283 
1284       vinfolog("Got TCP connection from %s", remote.toStringWithPort());
1285 
1286       ci->remote = remote;
1287       int pipe = g_tcpclientthreads->getThread();
1288       if (pipe >= 0) {
1289         queuedCounterIncremented = true;
1290         auto tmp = ci.release();
1291         try {
1292           // throws on failure
1293           writen2WithTimeout(pipe, &tmp, sizeof(tmp), 0);
1294         }
1295         catch (...) {
1296           delete tmp;
1297           tmp = nullptr;
1298           throw;
1299         }
1300       }
1301       else {
1302         g_tcpclientthreads->decrementQueuedCount();
1303         queuedCounterIncremented = false;
1304         if (tcpClientCountIncremented) {
1305           decrementTCPClientCount(remote);
1306         }
1307       }
1308     }
1309     catch (const std::exception& e) {
1310       errlog("While reading a TCP question: %s", e.what());
1311       if (tcpClientCountIncremented) {
1312         decrementTCPClientCount(remote);
1313       }
1314       if (queuedCounterIncremented) {
1315         g_tcpclientthreads->decrementQueuedCount();
1316       }
1317     }
1318     catch (...){}
1319   }
1320 }
1321