1 /*
2  * This file is part of PowerDNS or dnsdist.
3  * Copyright -- PowerDNS.COM B.V. and its contributors
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of version 2 of the GNU General Public License as
7  * published by the Free Software Foundation.
8  *
9  * In addition, for the avoidance of any doubt, permission is granted to
10  * link this program with OpenSSL and to (re)distribute the binaries
11  * produced as the result of such linking.
12  *
13  * This program is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21  */
22 
23 #include "config.h"
24 
25 #include <fstream>
26 #include <getopt.h>
27 #include <grp.h>
28 #include <limits>
29 #include <netinet/tcp.h>
30 #include <pwd.h>
31 #include <sys/resource.h>
32 #include <unistd.h>
33 
34 #if defined (__OpenBSD__) || defined(__NetBSD__)
// If this is not undeffed, __attribute__ will be redefined by /usr/include/readline/rlstdc.h
36 #undef __STRICT_ANSI__
37 #include <readline/readline.h>
38 #else
39 #include <editline/readline.h>
40 #endif
41 
42 #include "dnsdist-systemd.hh"
43 #ifdef HAVE_SYSTEMD
44 #include <systemd/sd-daemon.h>
45 #endif
46 
47 #include "dnsdist.hh"
48 #include "dnsdist-cache.hh"
49 #include "dnsdist-console.hh"
50 #include "dnsdist-dynblocks.hh"
51 #include "dnsdist-ecs.hh"
52 #include "dnsdist-healthchecks.hh"
53 #include "dnsdist-lua.hh"
54 #include "dnsdist-proxy-protocol.hh"
55 #include "dnsdist-rings.hh"
56 #include "dnsdist-secpoll.hh"
57 #include "dnsdist-web.hh"
58 #include "dnsdist-xpf.hh"
59 
60 #include "base64.hh"
61 #include "delaypipe.hh"
62 #include "dolog.hh"
63 #include "dnsname.hh"
64 #include "dnsparser.hh"
65 #include "ednsoptions.hh"
66 #include "gettime.hh"
67 #include "lock.hh"
68 #include "misc.hh"
69 #include "sodcrypto.hh"
70 #include "sstuff.hh"
71 #include "threadname.hh"
72 
73 /* Known sins:
74 
75    Receiver is currently single threaded
76       not *that* bad actually, but now that we are thread safe, might want to scale
77 */
78 
79 /* the RuleAction plan
80    Set of Rules, if one matches, it leads to an Action
81    Both rules and actions could conceivably be Lua based.
82    On the C++ side, both could be inherited from a class Rule and a class Action,
83    on the Lua side we can't do that. */
84 
using std::thread;
bool g_verbose;

// Process-wide counters, updated throughout this file (e.g. ++g_stats.responses in responderThread()).
struct DNSDistStats g_stats;

uint16_t g_maxOutstanding{std::numeric_limits<uint16_t>::max()};
uint32_t g_staleCacheEntriesTTL{0};
bool g_syslog{true};
// When true, responseContentMatches() accepts NoError/NXDomain responses whose question section is empty.
bool g_allowEmptyResponse{false};

GlobalStateHolder<NetmaskGroup> g_ACL;
string g_outputBuffer;

// Per-transport frontend contexts (TLS, DNS over HTTPS, DNSCrypt) — presumably filled from the configuration; TODO confirm.
std::vector<std::shared_ptr<TLSFrontend>> g_tlslocals;
std::vector<std::shared_ptr<DOHFrontend>> g_dohlocals;
std::vector<std::shared_ptr<DNSCryptContext>> g_dnsCryptLocals;

shared_ptr<BPFFilter> g_defaultBPFFilter{nullptr};
std::vector<std::shared_ptr<DynBPFFilter> > g_dynBPFFilters;

std::vector<std::unique_ptr<ClientState>> g_frontends;
GlobalStateHolder<pools_t> g_pools;
size_t g_udpVectorSize{1};
108 
109 /* UDP: the grand design. Per socket we listen on for incoming queries there is one thread.
110    Then we have a bunch of connected sockets for talking to downstream servers.
111    We send directly to those sockets.
112 
113    For the return path, per downstream server we have a thread that listens to responses.
114 
   Per downstream server there is an array of up to 2^16 states. When we send out a packet downstream, we note
   there the original requestor and the original id. The new ID is the offset in the array.
117 
118    When an answer comes in on a socket, we look up the offset by the id, and lob it to the
119    original requestor.
120 
121    IDs are assigned by atomic increments of the socket offset.
122  */
123 
// Query rule chain — presumably evaluated by applyRulesToQuery(); TODO confirm.
GlobalStateHolder<vector<DNSDistRuleAction> > g_ruleactions;
// Response rule chain; responderThread() takes a local copy and applies it through processResponse().
GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_respruleactions;
// Presumably the response rules for cache hits and for self-answered responses, respectively — TODO confirm.
GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_cachehitrespruleactions;
GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_selfansweredrespruleactions;

// Ring buffers of recent queries and responses (insertQuery() / insertResponse()).
Rings g_rings;
// Per-qname query counting, active when g_qcount.enabled (see applyRulesToQuery()).
QueryCount g_qcount;

GlobalStateHolder<servers_t> g_dstates;
// Dynamic blocks, keyed by netmask and by name suffix respectively.
GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
// Action applied by a dynamic block whose own action is DNSAction::Action::None.
DNSAction::Action g_dynBlockAction = DNSAction::Action::Drop;
int g_udpTimeout{2};

bool g_servFailOnNoPolicy{false};
// When set, backend responses with TC=1 are stripped down to header + question by truncateTC().
bool g_truncateTC{false};
// When set, fixUpResponse() copies the original qname bytes over the qname of the response.
bool g_fixupCase{false};
bool g_dropEmptyQueries{false};

std::set<std::string> g_capabilitiesToRetain;

// Initial size of UDP receive buffers: the largest cacheable response plus room for DNSCrypt padding and MAC.
static size_t const s_initialUDPPacketBufferSize = s_maxPacketCacheEntrySize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
static_assert(s_initialUDPPacketBufferSize <= UINT16_MAX, "Packet size should fit in a uint16_t");
147 
truncateTC(PacketBuffer & packet,size_t maximumSize,unsigned int qnameWireLength)148 static void truncateTC(PacketBuffer& packet, size_t maximumSize, unsigned int qnameWireLength)
149 {
150   try
151   {
152     bool hadEDNS = false;
153     uint16_t payloadSize = 0;
154     uint16_t z = 0;
155 
156     if (g_addEDNSToSelfGeneratedResponses) {
157       hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(packet.data()), packet.size(), &payloadSize, &z);
158     }
159 
160     packet.resize(static_cast<uint16_t>(sizeof(dnsheader)+qnameWireLength+DNS_TYPE_SIZE+DNS_CLASS_SIZE));
161     struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data());
162     dh->ancount = dh->arcount = dh->nscount = 0;
163 
164     if (hadEDNS) {
165       addEDNS(packet, maximumSize, z & EDNS_HEADER_FLAG_DO, payloadSize, 0);
166     }
167   }
168   catch(...)
169   {
170     ++g_stats.truncFail;
171   }
172 }
173 
174 struct DelayedPacket
175 {
176   int fd;
177   PacketBuffer packet;
178   ComboAddress destination;
179   ComboAddress origDest;
operator ()DelayedPacket180   void operator()()
181   {
182     ssize_t res;
183     if(origDest.sin4.sin_family == 0) {
184       res = sendto(fd, packet.data(), packet.size(), 0, (struct sockaddr*)&destination, destination.getSocklen());
185     }
186     else {
187       res = sendfromto(fd, packet.data(), packet.size(), 0, origDest, destination);
188     }
189     if (res == -1) {
190       int err = errno;
191       vinfolog("Error sending delayed response to %s: %s", destination.toStringWithPort(), strerror(err));
192     }
193   }
194 };
195 
196 DelayPipe<DelayedPacket>* g_delay = nullptr;
197 
getTrailingData() const198 std::string DNSQuestion::getTrailingData() const
199 {
200   const char* message = reinterpret_cast<const char*>(this->getHeader());
201   const uint16_t messageLen = getDNSPacketLength(message, this->data.size());
202   return std::string(message + messageLen, this->getData().size() - messageLen);
203 }
204 
setTrailingData(const std::string & tail)205 bool DNSQuestion::setTrailingData(const std::string& tail)
206 {
207   const char* message = reinterpret_cast<const char*>(this->data.data());
208   const uint16_t messageLen = getDNSPacketLength(message, this->data.size());
209   this->data.resize(messageLen);
210   if (tail.size() > 0) {
211     if (!hasRoomFor(tail.size())) {
212       return false;
213     }
214     this->data.insert(this->data.end(), tail.begin(), tail.end());
215   }
216   return true;
217 }
218 
doLatencyStats(double udiff)219 void doLatencyStats(double udiff)
220 {
221   if(udiff < 1000) ++g_stats.latency0_1;
222   else if(udiff < 10000) ++g_stats.latency1_10;
223   else if(udiff < 50000) ++g_stats.latency10_50;
224   else if(udiff < 100000) ++g_stats.latency50_100;
225   else if(udiff < 1000000) ++g_stats.latency100_1000;
226   else ++g_stats.latencySlow;
227   g_stats.latencySum += udiff / 1000;
228 
229   auto doAvg = [](double& var, double n, double weight) {
230     var = (weight -1) * var/weight + n/weight;
231   };
232 
233   doAvg(g_stats.latencyAvg100,     udiff,     100);
234   doAvg(g_stats.latencyAvg1000,    udiff,    1000);
235   doAvg(g_stats.latencyAvg10000,   udiff,   10000);
236   doAvg(g_stats.latencyAvg1000000, udiff, 1000000);
237 }
238 
responseContentMatches(const PacketBuffer & response,const DNSName & qname,const uint16_t qtype,const uint16_t qclass,const ComboAddress & remote,unsigned int & qnameWireLength)239 bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& qnameWireLength)
240 {
241   if (response.size() < sizeof(dnsheader)) {
242     return false;
243   }
244 
245   const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response.data());
246   if (dh->qr == 0) {
247     ++g_stats.nonCompliantResponses;
248     return false;
249   }
250 
251   if (dh->qdcount == 0) {
252     if ((dh->rcode != RCode::NoError && dh->rcode != RCode::NXDomain) || g_allowEmptyResponse) {
253       return true;
254     }
255     else {
256       ++g_stats.nonCompliantResponses;
257       return false;
258     }
259   }
260 
261   uint16_t rqtype, rqclass;
262   DNSName rqname;
263   try {
264     rqname = DNSName(reinterpret_cast<const char*>(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass, &qnameWireLength);
265   }
266   catch (const std::exception& e) {
267     if(response.size() > 0 && static_cast<size_t>(response.size()) > sizeof(dnsheader)) {
268       infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what());
269     }
270     ++g_stats.nonCompliantResponses;
271     return false;
272   }
273 
274   if (rqtype != qtype || rqclass != qclass || rqname != qname) {
275     return false;
276   }
277 
278   return true;
279 }
280 
restoreFlags(struct dnsheader * dh,uint16_t origFlags)281 static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
282 {
283   static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
284   static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
285   static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
286   uint16_t* flags = getFlagsFromDNSHeader(dh);
287   /* clear the flags we are about to restore */
288   *flags &= restoreFlagsMask;
289   /* only keep the flags we want to restore */
290   origFlags &= ~restoreFlagsMask;
291   /* set the saved flags as they were */
292   *flags |= origFlags;
293 }
294 
fixUpQueryTurnedResponse(DNSQuestion & dq,const uint16_t origFlags)295 static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
296 {
297   restoreFlags(dq.getHeader(), origFlags);
298 
299   return addEDNSToQueryTurnedResponse(dq);
300 }
301 
/* Undo, on a backend response, what was changed in the query before it was
   forwarded:
   - restore the RD and CD flags from origFlags;
   - if g_fixupCase is set, copy the original qname's storage over the qname
     that follows the header (presumably to restore the client's
     capitalization; this assumes both names have the same length);
   - if the whole OPT RR was added (ednsAdded), remove it; otherwise, if only
     an ECS option was added to an existing OPT RR (ecsAdded), remove that
     option.
   If zeroScope is not null and an ECS option with at least 4 bytes of
   content is found, *zeroScope is set to whether its SCOPE PREFIX-LENGTH is
   0; otherwise it is left untouched.
   Returns false only when the response is shorter than a DNS header. Rewrite
   failures are logged and the function still returns true. */
static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, bool* zeroScope)
{
  if (response.size() < sizeof(dnsheader)) {
    return false;
  }

  struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
  restoreFlags(dh, origFlags);

  /* header-only response: nothing else to fix up */
  if (response.size() == sizeof(dnsheader)) {
    return true;
  }

  if (g_fixupCase) {
    const auto& realname = qname.getStorage();
    /* only overwrite when the packet is long enough to hold the name */
    if (response.size() >= (sizeof(dnsheader) + realname.length())) {
      memcpy(&response.at(sizeof(dnsheader)), realname.c_str(), realname.length());
    }
  }

  if (ednsAdded || ecsAdded) {
    /* only read when locateEDNSOptRR() succeeds */
    uint16_t optStart;
    size_t optLen = 0;
    bool last = false;

    /* find the OPT RR: its offset, its length and whether it is the last record */
    int res = locateEDNSOptRR(response, &optStart, &optLen, &last);

    if (res == 0) {
      if (zeroScope) { // this finds if an EDNS Client Subnet scope was set, and if it is 0
        size_t optContentStart = 0;
        uint16_t optContentLen = 0;
        /* we need at least 4 bytes after the option length (family: 2, source prefix-length: 1, scope prefix-length: 1) */
        if (isEDNSOptionInOpt(response, optStart, optLen, EDNSOptionCode::ECS, &optContentStart, &optContentLen) && optContentLen >= 4) {
          /* see if the EDNS Client Subnet SCOPE PREFIX-LENGTH byte in position 3 is set to 0, which is the only thing
             we care about. */
          *zeroScope = response.at(optContentStart + 3) == 0;
        }
      }

      if (ednsAdded) {
        /* we added the entire OPT RR,
           therefore we need to remove it entirely */
        if (last) {
          /* simply remove the last AR */
          response.resize(response.size() - optLen);
          /* re-fetch the header pointer after resizing the buffer */
          dh = reinterpret_cast<struct dnsheader*>(response.data());
          uint16_t arcount = ntohs(dh->arcount);
          arcount--;
          dh->arcount = htons(arcount);
        }
        else {
          /* Removing an intermediary RR could lead to compression error */
          PacketBuffer rewrittenResponse;
          if (rewriteResponseWithoutEDNS(response, rewrittenResponse) == 0) {
            response = std::move(rewrittenResponse);
          }
          else {
            warnlog("Error rewriting content");
          }
        }
      }
      else {
        /* the OPT RR was already present, but without ECS,
           we need to remove the ECS option if any */
        if (last) {
          /* nothing after the OPT RR, we can simply remove the
             ECS option */
          size_t existingOptLen = optLen;
          /* optLen is updated in place to the OPT RR's new length */
          removeEDNSOptionFromOPT(reinterpret_cast<char*>(&response.at(optStart)), &optLen, EDNSOptionCode::ECS);
          response.resize(response.size() - (existingOptLen - optLen));
        }
        else {
          PacketBuffer rewrittenResponse;
          /* Removing an intermediary RR could lead to compression error */
          if (rewriteResponseWithoutEDNSOption(response, EDNSOptionCode::ECS, rewrittenResponse) == 0) {
            response = std::move(rewrittenResponse);
          }
          else {
            warnlog("Error rewriting content");
          }
        }
      }
    }
  }

  return true;
}
389 
#ifdef HAVE_DNSCRYPT
/* Encrypt a response in place for a DNSCrypt query.
   When dnsCryptQuery is null there is nothing to encrypt and true is
   returned. Returns false, after logging, when encryption fails; the caller
   must then drop the response.
   The query is taken by const reference: no ownership is shared, and
   copying a shared_ptr costs an atomic reference-count round-trip on every
   response. */
static bool encryptResponse(PacketBuffer& response, size_t maximumSize, bool tcp, const std::shared_ptr<DNSCryptQuery>& dnsCryptQuery)
{
  if (!dnsCryptQuery) {
    return true;
  }

  int res = dnsCryptQuery->encryptResponse(response, maximumSize, tcp);
  if (res != 0) {
    /* dropping response */
    vinfolog("Error encrypting the response, dropping.");
    return false;
  }
  return true;
}
#endif /* HAVE_DNSCRYPT */
404 
applyRulesToResponse(LocalStateHolder<vector<DNSDistResponseRuleAction>> & localRespRuleActions,DNSResponse & dr)405 static bool applyRulesToResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRuleActions, DNSResponse& dr)
406 {
407   DNSResponseAction::Action action=DNSResponseAction::Action::None;
408   std::string ruleresult;
409   for(const auto& lr : *localRespRuleActions) {
410     if(lr.d_rule->matches(&dr)) {
411       lr.d_rule->d_matches++;
412       action=(*lr.d_action)(&dr, &ruleresult);
413       switch(action) {
414       case DNSResponseAction::Action::Allow:
415         return true;
416         break;
417       case DNSResponseAction::Action::Drop:
418         return false;
419         break;
420       case DNSResponseAction::Action::HeaderModify:
421         return true;
422         break;
423       case DNSResponseAction::Action::ServFail:
424         dr.getHeader()->rcode = RCode::ServFail;
425         return true;
426         break;
427         /* non-terminal actions follow */
428       case DNSResponseAction::Action::Delay:
429         dr.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
430         break;
431       case DNSResponseAction::Action::None:
432         break;
433       }
434     }
435   }
436 
437   return true;
438 }
439 
processResponse(PacketBuffer & response,LocalStateHolder<vector<DNSDistResponseRuleAction>> & localRespRuleActions,DNSResponse & dr,bool muted)440 bool processResponse(PacketBuffer& response, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRuleActions, DNSResponse& dr, bool muted)
441 {
442   if (!applyRulesToResponse(localRespRuleActions, dr)) {
443     return false;
444   }
445 
446   bool zeroScope = false;
447   if (!fixUpResponse(response, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, dr.useZeroScope ? &zeroScope : nullptr)) {
448     return false;
449   }
450 
451   if (dr.packetCache && !dr.skipCache && response.size() <= s_maxPacketCacheEntrySize) {
452     if (!dr.useZeroScope) {
453       /* if the query was not suitable for zero-scope, for
454          example because it had an existing ECS entry so the hash is
455          not really 'no ECS', so just insert it for the existing subnet
456          since:
457          - we don't have the correct hash for a non-ECS query
458          - inserting with hash computed before the ECS replacement but with
459          the subnet extracted _after_ the replacement would not work.
460       */
461       zeroScope = false;
462     }
463     // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
464     dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.cacheFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, response, dr.tcp, dr.getHeader()->rcode, dr.tempFailureTTL);
465   }
466 
467 #ifdef HAVE_DNSCRYPT
468   if (!muted) {
469     if (!encryptResponse(response, dr.getMaximumSize(), dr.tcp, dr.dnsCryptQuery)) {
470       return false;
471     }
472   }
473 #endif /* HAVE_DNSCRYPT */
474 
475   return true;
476 }
477 
getInitialUDPPacketBufferSize()478 static size_t getInitialUDPPacketBufferSize()
479 {
480   static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "The incoming buffer size should not be larger than s_initialUDPPacketBufferSize");
481 
482   if (g_proxyProtocolACL.empty()) {
483     return s_initialUDPPacketBufferSize;
484   }
485 
486   return s_initialUDPPacketBufferSize + g_proxyProtocolMaximumSize;
487 }
488 
getMaximumIncomingPacketSize(const ClientState & cs)489 static size_t getMaximumIncomingPacketSize(const ClientState& cs)
490 {
491   if (cs.dnscryptCtx) {
492     return getInitialUDPPacketBufferSize();
493   }
494 
495   if (g_proxyProtocolACL.empty()) {
496     return s_udpIncomingBufferSize;
497   }
498 
499   return s_udpIncomingBufferSize + g_proxyProtocolMaximumSize;
500 }
501 
sendUDPResponse(int origFD,const PacketBuffer & response,const int delayMsec,const ComboAddress & origDest,const ComboAddress & origRemote)502 static bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
503 {
504   if(delayMsec && g_delay) {
505     DelayedPacket dp{origFD, response, origRemote, origDest};
506     g_delay->submit(dp, delayMsec);
507   }
508   else {
509     ssize_t res;
510     if (origDest.sin4.sin_family == 0) {
511       res = sendto(origFD, response.data(), response.size(), 0, reinterpret_cast<const struct sockaddr*>(&origRemote), origRemote.getSocklen());
512     }
513     else {
514       res = sendfromto(origFD, response.data(), response.size(), 0, origDest, origRemote);
515     }
516     if (res == -1) {
517       int err = errno;
518       vinfolog("Error sending response to %s: %s", origRemote.toStringWithPort(), stringerror(err));
519     }
520   }
521 
522   return true;
523 }
524 
pickBackendSocketForSending(std::shared_ptr<DownstreamState> & state)525 int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state)
526 {
527   return state->sockets[state->socketsOffset++ % state->sockets.size()];
528 }
529 
pickBackendSocketsReadyForReceiving(const std::shared_ptr<DownstreamState> & state,std::vector<int> & ready)530 static void pickBackendSocketsReadyForReceiving(const std::shared_ptr<DownstreamState>& state, std::vector<int>& ready)
531 {
532   ready.clear();
533 
534   if (state->sockets.size() == 1) {
535     ready.push_back(state->sockets[0]);
536     return ;
537   }
538 
539   {
540     std::lock_guard<std::mutex> lock(state->socketsLock);
541     state->mplexer->getAvailableFDs(ready, 1000);
542   }
543 }
544 
545 // listens on a dedicated socket, lobs answers from downstream servers to original requestors
responderThread(std::shared_ptr<DownstreamState> dss)546 void responderThread(std::shared_ptr<DownstreamState> dss)
547 {
548   try {
549   setThreadName("dnsdist/respond");
550   auto localRespRuleActions = g_respruleactions.getLocal();
551   const size_t initialBufferSize = getInitialUDPPacketBufferSize();
552   PacketBuffer response(initialBufferSize);
553 
554   /* when the answer is encrypted in place, we need to get a copy
555      of the original header before encryption to fill the ring buffer */
556   dnsheader cleartextDH;
557   uint16_t queryId = 0;
558   std::vector<int> sockets;
559   sockets.reserve(dss->sockets.size());
560 
561   for(;;) {
562     try {
563       pickBackendSocketsReadyForReceiving(dss, sockets);
564       if (dss->isStopped()) {
565         break;
566       }
567 
568       for (const auto& fd : sockets) {
569         response.resize(initialBufferSize);
570         ssize_t got = recv(fd, response.data(), response.size(), 0);
571 
572         if (got == 0 && dss->isStopped()) {
573           break;
574         }
575 
576         if (got < 0 || static_cast<size_t>(got) < sizeof(dnsheader)) {
577           continue;
578         }
579 
580         response.resize(static_cast<size_t>(got));
581         dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
582         queryId = dh->id;
583 
584         if (queryId >= dss->idStates.size()) {
585           continue;
586         }
587 
588         IDState* ids = &dss->idStates[queryId];
589         int64_t usageIndicator = ids->usageIndicator;
590 
591         if (!IDState::isInUse(usageIndicator)) {
592           /* the corresponding state is marked as not in use, meaning that:
593              - it was already cleaned up by another thread and the state is gone ;
594              - we already got a response for this query and this one is a duplicate.
595              Either way, we don't touch it.
596           */
597           continue;
598         }
599 
600         /* read the potential DOHUnit state as soon as possible, but don't use it
601            until we have confirmed that we own this state by updating usageIndicator */
602         auto du = ids->du;
603         /* setting age to 0 to prevent the maintainer thread from
604            cleaning this IDS while we process the response.
605         */
606         ids->age = 0;
607         int origFD = ids->origFD;
608 
609         unsigned int qnameWireLength = 0;
610         if (!responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss->remote, qnameWireLength)) {
611           continue;
612         }
613 
614         bool isDoH = du != nullptr;
615         /* atomically mark the state as available, but only if it has not been altered
616            in the meantime */
617         if (ids->tryMarkUnused(usageIndicator)) {
618           /* clear the potential DOHUnit asap, it's ours now
619            and since we just marked the state as unused,
620            someone could overwrite it. */
621           ids->du = nullptr;
622           /* we only decrement the outstanding counter if the value was not
623              altered in the meantime, which would mean that the state has been actively reused
624              and the other thread has not incremented the outstanding counter, so we don't
625              want it to be decremented twice. */
626           --dss->outstanding;  // you'd think an attacker could game this, but we're using connected socket
627         } else {
628           /* someone updated the state in the meantime, we can't touch the existing pointer */
629           du = nullptr;
630           /* since the state has been updated, we can't safely access it so let's just drop
631              this response */
632           continue;
633         }
634 
635         dh->id = ids->origID;
636 
637         DNSResponse dr = makeDNSResponseFromIDState(*ids, response, false);
638         dr.du = du;
639         if (dh->tc && g_truncateTC) {
640           truncateTC(response, dr.getMaximumSize(), qnameWireLength);
641         }
642         memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
643 
644         if (!processResponse(response, localRespRuleActions, dr, ids->cs && ids->cs->muted)) {
645           continue;
646         }
647 
648         if (ids->cs && !ids->cs->muted) {
649           if (du) {
650             dr.du = nullptr;
651 #ifdef HAVE_DNS_OVER_HTTPS
652             // DoH query
653             du->response = std::move(response);
654             static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
655             ssize_t sent = write(du->rsock, &du, sizeof(du));
656             if (sent != sizeof(du)) {
657               if (errno == EAGAIN || errno == EWOULDBLOCK) {
658                 ++g_stats.dohResponsePipeFull;
659                 vinfolog("Unable to pass a DoH response to the DoH worker thread because the pipe is full");
660               }
661               else {
662                 vinfolog("Unable to pass a DoH response to the DoH worker thread because we couldn't write to the pipe: %s", stringerror());
663               }
664 
665               /* at this point we have the only remaining pointer on this
666                  DOHUnit object since we did set ids->du to nullptr earlier,
667                  except if we got the response before the pointer could be
668                  released by the frontend */
669               du->release();
670             }
671 #endif /* HAVE_DNS_OVER_HTTPS */
672             du = nullptr;
673           }
674           else {
675             ComboAddress empty;
676             empty.sin4.sin_family = 0;
677             sendUDPResponse(origFD, response, dr.delayMsec, ids->hopLocal, ids->hopRemote);
678           }
679         }
680 
681         ++g_stats.responses;
682         if (ids->cs) {
683           ++ids->cs->responses;
684         }
685         ++dss->responses;
686 
687         double udiff = ids->sentTime.udiff();
688         vinfolog("Got answer from %s, relayed to %s%s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(),
689                  isDoH ? " (https)": "", udiff);
690 
691         struct timespec ts;
692         gettime(&ts);
693         g_rings.insertResponse(ts, *dr.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(got), cleartextDH, dss->remote);
694 
695         switch (cleartextDH.rcode) {
696         case RCode::NXDomain:
697           ++g_stats.frontendNXDomain;
698           break;
699         case RCode::ServFail:
700           ++g_stats.servfailResponses;
701           ++g_stats.frontendServFail;
702           break;
703         case RCode::NoError:
704           ++g_stats.frontendNoError;
705           break;
706         }
707         dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff/128.0;
708 
709         doLatencyStats(udiff);
710       }
711     }
712     catch (const std::exception& e){
713       vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->remote.toStringWithPort(), queryId, e.what());
714     }
715   }
716 }
717 catch (const std::exception& e)
718 {
719   errlog("UDP responder thread died because of exception: %s", e.what());
720 }
721 catch (const PDNSException& e)
722 {
723   errlog("UDP responder thread died because of PowerDNS exception: %s", e.reason);
724 }
725 catch (...)
726 {
727   errlog("UDP responder thread died because of an exception: %s", "unknown");
728 }
729 }
730 
// Held around Lua callbacks such as the g_qcount filter in applyRulesToQuery(); presumably guards g_lua — TODO confirm.
std::mutex g_luamutex;
LuaContext g_lua;
// Default address for the control socket — presumably used by the console; TODO confirm.
ComboAddress g_serverControl{"127.0.0.1:5199"};
734 
735 
spoofResponseFromString(DNSQuestion & dq,const string & spoofContent,bool raw)736 static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent, bool raw)
737 {
738   string result;
739 
740   if (raw) {
741     std::vector<std::string> raws;
742     stringtok(raws, spoofContent, ",");
743     SpoofAction sa(raws);
744     sa(&dq, &result);
745   }
746   else {
747     std::vector<std::string> addrs;
748     stringtok(addrs, spoofContent, " ,");
749 
750     if (addrs.size() == 1) {
751       try {
752         ComboAddress spoofAddr(spoofContent);
753         SpoofAction sa({spoofAddr});
754         sa(&dq, &result);
755       }
756       catch(const PDNSException &e) {
757         DNSName cname(spoofContent);
758         SpoofAction sa(cname); // CNAME then
759         sa(&dq, &result);
760       }
761     } else {
762       std::vector<ComboAddress> cas;
763       for (const auto& addr : addrs) {
764         try {
765           cas.push_back(ComboAddress(addr));
766         }
767         catch (...) {
768         }
769       }
770       SpoofAction sa(cas);
771       sa(&dq, &result);
772     }
773   }
774 }
775 
processRulesResult(const DNSAction::Action & action,DNSQuestion & dq,std::string & ruleresult,bool & drop)776 bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop)
777 {
778   switch(action) {
779   case DNSAction::Action::Allow:
780     return true;
781     break;
782   case DNSAction::Action::Drop:
783     ++g_stats.ruleDrop;
784     drop = true;
785     return true;
786     break;
787   case DNSAction::Action::Nxdomain:
788     dq.getHeader()->rcode = RCode::NXDomain;
789     dq.getHeader()->qr=true;
790     ++g_stats.ruleNXDomain;
791     return true;
792     break;
793   case DNSAction::Action::Refused:
794     dq.getHeader()->rcode = RCode::Refused;
795     dq.getHeader()->qr=true;
796     ++g_stats.ruleRefused;
797     return true;
798     break;
799   case DNSAction::Action::ServFail:
800     dq.getHeader()->rcode = RCode::ServFail;
801     dq.getHeader()->qr=true;
802     ++g_stats.ruleServFail;
803     return true;
804     break;
805   case DNSAction::Action::Spoof:
806     spoofResponseFromString(dq, ruleresult, false);
807     return true;
808     break;
809   case DNSAction::Action::SpoofRaw:
810     spoofResponseFromString(dq, ruleresult, true);
811     return true;
812     break;
813   case DNSAction::Action::Truncate:
814     dq.getHeader()->tc = true;
815     dq.getHeader()->qr = true;
816     dq.getHeader()->ra = dq.getHeader()->rd;
817     dq.getHeader()->aa = false;
818     dq.getHeader()->ad = false;
819     ++g_stats.ruleTruncated;
820     return true;
821     break;
822   case DNSAction::Action::HeaderModify:
823     return true;
824     break;
825   case DNSAction::Action::Pool:
826     dq.poolname=ruleresult;
827     return true;
828     break;
829   case DNSAction::Action::NoRecurse:
830     dq.getHeader()->rd = false;
831     return true;
832     break;
833     /* non-terminal actions follow */
834   case DNSAction::Action::Delay:
835     dq.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
836     break;
837   case DNSAction::Action::None:
838     /* fall-through */
839   case DNSAction::Action::NoOp:
840     break;
841   }
842 
843   /* false means that we don't stop the processing */
844   return false;
845 }
846 
847 
applyRulesToQuery(LocalHolders & holders,DNSQuestion & dq,const struct timespec & now)848 static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const struct timespec& now)
849 {
850   g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.getData().size(), *dq.getHeader());
851 
852   if(g_qcount.enabled) {
853     string qname = (*dq.qname).toLogString();
854     bool countQuery{true};
855     if(g_qcount.filter) {
856       std::lock_guard<std::mutex> lock(g_luamutex);
857       std::tie (countQuery, qname) = g_qcount.filter(&dq);
858     }
859 
860     if(countQuery) {
861       WriteLock wl(&g_qcount.queryLock);
862       if(!g_qcount.records.count(qname)) {
863         g_qcount.records[qname] = 0;
864       }
865       g_qcount.records[qname]++;
866     }
867   }
868 
869   if(auto got = holders.dynNMGBlock->lookup(*dq.remote)) {
870     auto updateBlockStats = [&got]() {
871       ++g_stats.dynBlocked;
872       got->second.blocks++;
873     };
874 
875     if(now < got->second.until) {
876       DNSAction::Action action = got->second.action;
877       if (action == DNSAction::Action::None) {
878         action = g_dynBlockAction;
879       }
880       switch (action) {
881       case DNSAction::Action::NoOp:
882         /* do nothing */
883         break;
884 
885       case DNSAction::Action::Nxdomain:
886         vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort());
887         updateBlockStats();
888 
889         dq.getHeader()->rcode = RCode::NXDomain;
890         dq.getHeader()->qr=true;
891         return true;
892 
893       case DNSAction::Action::Refused:
894         vinfolog("Query from %s refused because of dynamic block", dq.remote->toStringWithPort());
895         updateBlockStats();
896 
897         dq.getHeader()->rcode = RCode::Refused;
898         dq.getHeader()->qr = true;
899         return true;
900 
901       case DNSAction::Action::Truncate:
902         if(!dq.tcp) {
903           updateBlockStats();
904           vinfolog("Query from %s truncated because of dynamic block", dq.remote->toStringWithPort());
905           dq.getHeader()->tc = true;
906           dq.getHeader()->qr = true;
907           dq.getHeader()->ra = dq.getHeader()->rd;
908           dq.getHeader()->aa = false;
909           dq.getHeader()->ad = false;
910           return true;
911         }
912         else {
913           vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
914         }
915         break;
916       case DNSAction::Action::NoRecurse:
917         updateBlockStats();
918         vinfolog("Query from %s setting rd=0 because of dynamic block", dq.remote->toStringWithPort());
919         dq.getHeader()->rd = false;
920         return true;
921       default:
922         updateBlockStats();
923         vinfolog("Query from %s dropped because of dynamic block", dq.remote->toStringWithPort());
924         return false;
925       }
926     }
927   }
928 
929   if(auto got = holders.dynSMTBlock->lookup(*dq.qname)) {
930     auto updateBlockStats = [&got]() {
931       ++g_stats.dynBlocked;
932       got->blocks++;
933     };
934 
935     if(now < got->until) {
936       DNSAction::Action action = got->action;
937       if (action == DNSAction::Action::None) {
938         action = g_dynBlockAction;
939       }
940       switch (action) {
941       case DNSAction::Action::NoOp:
942         /* do nothing */
943         break;
944       case DNSAction::Action::Nxdomain:
945         vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
946         updateBlockStats();
947 
948         dq.getHeader()->rcode = RCode::NXDomain;
949         dq.getHeader()->qr=true;
950         return true;
951       case DNSAction::Action::Refused:
952         vinfolog("Query from %s for %s refused because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
953         updateBlockStats();
954 
955         dq.getHeader()->rcode = RCode::Refused;
956         dq.getHeader()->qr=true;
957         return true;
958       case DNSAction::Action::Truncate:
959         if(!dq.tcp) {
960           updateBlockStats();
961 
962           vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
963           dq.getHeader()->tc = true;
964           dq.getHeader()->qr = true;
965           dq.getHeader()->ra = dq.getHeader()->rd;
966           dq.getHeader()->aa = false;
967           dq.getHeader()->ad = false;
968           return true;
969         }
970         else {
971           vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
972         }
973         break;
974       case DNSAction::Action::NoRecurse:
975         updateBlockStats();
976         vinfolog("Query from %s setting rd=0 because of dynamic block", dq.remote->toStringWithPort());
977         dq.getHeader()->rd = false;
978         return true;
979       default:
980         updateBlockStats();
981         vinfolog("Query from %s for %s dropped because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
982         return false;
983       }
984     }
985   }
986 
987   DNSAction::Action action=DNSAction::Action::None;
988   string ruleresult;
989   bool drop = false;
990   for(const auto& lr : *holders.ruleactions) {
991     if(lr.d_rule->matches(&dq)) {
992       lr.d_rule->d_matches++;
993       action=(*lr.d_action)(&dq, &ruleresult);
994       if (processRulesResult(action, dq, ruleresult, drop)) {
995         break;
996       }
997     }
998   }
999 
1000   if (drop) {
1001     return false;
1002   }
1003 
1004   return true;
1005 }
1006 
udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState> & ss,const int sd,const PacketBuffer & request,bool healthCheck)1007 ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck)
1008 {
1009   ssize_t result;
1010 
1011   if (ss->sourceItf == 0) {
1012     result = send(sd, request.data(), request.size(), 0);
1013   }
1014   else {
1015     struct msghdr msgh;
1016     struct iovec iov;
1017     cmsgbuf_aligned cbuf;
1018     ComboAddress remote(ss->remote);
1019     fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), const_cast<char*>(reinterpret_cast<const char *>(request.data())), request.size(), &remote);
1020     addCMsgSrcAddr(&msgh, &cbuf, &ss->sourceAddr, ss->sourceItf);
1021     result = sendmsg(sd, &msgh, 0);
1022   }
1023 
1024   if (result == -1) {
1025     int savederrno = errno;
1026     vinfolog("Error sending request to backend %s: %d", ss->remote.toStringWithPort(), savederrno);
1027 
1028     /* This might sound silly, but on Linux send() might fail with EINVAL
1029        if the interface the socket was bound to doesn't exist anymore.
1030        We don't want to reconnect the real socket if the healthcheck failed,
1031        because it's not using the same socket.
1032     */
1033     if (!healthCheck && (savederrno == EINVAL || savederrno == ENODEV)) {
1034       ss->reconnect();
1035     }
1036   }
1037 
1038   return result;
1039 }
1040 
isUDPQueryAcceptable(ClientState & cs,LocalHolders & holders,const struct msghdr * msgh,const ComboAddress & remote,ComboAddress & dest,bool & expectProxyProtocol)1041 static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, bool& expectProxyProtocol)
1042 {
1043   if (msgh->msg_flags & MSG_TRUNC) {
1044     /* message was too large for our buffer */
1045     vinfolog("Dropping message too large for our buffer");
1046     ++g_stats.nonCompliantQueries;
1047     return false;
1048   }
1049 
1050   expectProxyProtocol = expectProxyProtocolFrom(remote);
1051   if (!holders.acl->match(remote) && !expectProxyProtocol) {
1052     vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort());
1053     ++g_stats.aclDrops;
1054     return false;
1055   }
1056 
1057   if (HarvestDestinationAddress(msgh, &dest)) {
1058     /* we don't get the port, only the address */
1059     dest.sin4.sin_port = cs.local.sin4.sin_port;
1060   }
1061   else {
1062     dest.sin4.sin_family = 0;
1063   }
1064 
1065   cs.queries++;
1066   ++g_stats.queries;
1067 
1068   return true;
1069 }
1070 
checkDNSCryptQuery(const ClientState & cs,PacketBuffer & query,std::shared_ptr<DNSCryptQuery> & dnsCryptQuery,time_t now,bool tcp)1071 bool checkDNSCryptQuery(const ClientState& cs, PacketBuffer& query, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
1072 {
1073   if (cs.dnscryptCtx) {
1074 #ifdef HAVE_DNSCRYPT
1075     PacketBuffer response;
1076     dnsCryptQuery = std::make_shared<DNSCryptQuery>(cs.dnscryptCtx);
1077 
1078     bool decrypted = handleDNSCryptQuery(query, dnsCryptQuery, tcp, now, response);
1079 
1080     if (!decrypted) {
1081       if (response.size() > 0) {
1082         query = std::move(response);
1083         return true;
1084       }
1085       throw std::runtime_error("Unable to decrypt DNSCrypt query, dropping.");
1086     }
1087 #endif /* HAVE_DNSCRYPT */
1088   }
1089   return false;
1090 }
1091 
checkQueryHeaders(const struct dnsheader * dh)1092 bool checkQueryHeaders(const struct dnsheader* dh)
1093 {
1094   if (dh->qr) {   // don't respond to responses
1095     ++g_stats.nonCompliantQueries;
1096     return false;
1097   }
1098 
1099   if (dh->qdcount == 0) {
1100     ++g_stats.emptyQueries;
1101     if (g_dropEmptyQueries) {
1102       return false;
1103     }
1104   }
1105 
1106   if (dh->rd) {
1107     ++g_stats.rdQueries;
1108   }
1109 
1110   return true;
1111 }
1112 
#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
/* Prepares 'outMsg' so that a later sendmmsg() sends 'response' to 'remote'.
   'iov' and 'cbuf' are caller-provided storage the message points into, so they must stay
   valid until the batch is sent. When 'dest' is known (non-zero family), a source-address
   control message is attached so the reply leaves from the address the query was sent to;
   otherwise no control data is attached. 'cs' is currently unused. */
static void queueResponse(const ClientState& cs, const PacketBuffer& response, const ComboAddress& dest, const ComboAddress& remote, struct mmsghdr& outMsg, struct iovec* iov, cmsgbuf_aligned* cbuf)
{
  outMsg.msg_len = 0;

  auto* payload = const_cast<char*>(reinterpret_cast<const char *>(&response.at(0)));
  auto* target = const_cast<ComboAddress*>(&remote);
  fillMSGHdr(&outMsg.msg_hdr, iov, nullptr, 0, payload, response.size(), target);

  const bool destinationKnown = dest.sin4.sin_family != 0;
  if (!destinationKnown) {
    outMsg.msg_hdr.msg_control = nullptr;
    return;
  }

  addCMsgSrcAddr(&outMsg.msg_hdr, cbuf, &dest, 0);
}
#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
1127 
1128 /* self-generated responses or cache hits */
prepareOutgoingResponse(LocalHolders & holders,ClientState & cs,DNSQuestion & dq,bool cacheHit)1129 static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit)
1130 {
1131   DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, dq.getMutableData(), dq.tcp, dq.queryTime);
1132 
1133   dr.uniqueId = dq.uniqueId;
1134   dr.qTag = dq.qTag;
1135   dr.delayMsec = dq.delayMsec;
1136 
1137   if (!applyRulesToResponse(cacheHit ? holders.cacheHitRespRuleactions : holders.selfAnsweredRespRuleactions, dr)) {
1138     return false;
1139   }
1140 
1141   /* in case a rule changed it */
1142   dq.delayMsec = dr.delayMsec;
1143 
1144 #ifdef HAVE_DNSCRYPT
1145   if (!cs.muted) {
1146     if (!encryptResponse(dq.getMutableData(), dq.getMaximumSize(), dq.tcp, dq.dnsCryptQuery)) {
1147       return false;
1148     }
1149   }
1150 #endif /* HAVE_DNSCRYPT */
1151 
1152   if (cacheHit) {
1153     ++g_stats.cacheHits;
1154   }
1155 
1156   switch (dr.getHeader()->rcode) {
1157   case RCode::NXDomain:
1158     ++g_stats.frontendNXDomain;
1159     break;
1160   case RCode::ServFail:
1161     ++g_stats.frontendServFail;
1162     break;
1163   case RCode::NoError:
1164     ++g_stats.frontendNoError;
1165     break;
1166   }
1167 
1168   doLatencyStats(0);  // we're not going to measure this
1169   return true;
1170 }
1171 
processQuery(DNSQuestion & dq,ClientState & cs,LocalHolders & holders,std::shared_ptr<DownstreamState> & selectedBackend)1172 ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
1173 {
1174   const uint16_t queryId = ntohs(dq.getHeader()->id);
1175 
1176   try {
1177     /* we need an accurate ("real") value for the response and
1178        to store into the IDS, but not for insertion into the
1179        rings for example */
1180     struct timespec now;
1181     gettime(&now);
1182 
1183     if (!applyRulesToQuery(holders, dq, now)) {
1184       return ProcessQueryResult::Drop;
1185     }
1186 
1187     if (dq.getHeader()->qr) { // something turned it into a response
1188       fixUpQueryTurnedResponse(dq, dq.origFlags);
1189 
1190       if (!prepareOutgoingResponse(holders, cs, dq, false)) {
1191         return ProcessQueryResult::Drop;
1192       }
1193 
1194       ++g_stats.selfAnswered;
1195       ++cs.responses;
1196       return ProcessQueryResult::SendAnswer;
1197     }
1198 
1199     std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, dq.poolname);
1200     std::shared_ptr<ServerPolicy> poolPolicy = serverPool->policy;
1201     dq.packetCache = serverPool->packetCache;
1202     const auto& policy = poolPolicy != nullptr ? *poolPolicy : *(holders.policy);
1203     const auto servers = serverPool->getServers();
1204     selectedBackend = policy.getSelectedBackend(*servers, dq);
1205 
1206     uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL;
1207 
1208     if (dq.packetCache && !dq.skipCache) {
1209       dq.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
1210     }
1211 
1212     if (dq.useECS && ((selectedBackend && selectedBackend->useECS) || (!selectedBackend && serverPool->getECS()))) {
1213       // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
1214       // we need ECS parsing (parseECS) to be true so we can be sure that the initial incoming query did not have an existing
1215       // ECS option, which would make it unsuitable for the zero-scope feature.
1216       if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
1217         if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) {
1218 
1219           if (!prepareOutgoingResponse(holders, cs, dq, true)) {
1220             return ProcessQueryResult::Drop;
1221           }
1222 
1223           return ProcessQueryResult::SendAnswer;
1224         }
1225 
1226         if (!dq.subnet) {
1227           /* there was no existing ECS on the query, enable the zero-scope feature */
1228           dq.useZeroScope = true;
1229         }
1230       }
1231 
1232       if (!handleEDNSClientSubnet(dq, dq.ednsAdded, dq.ecsAdded)) {
1233         vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort());
1234         return ProcessQueryResult::Drop;
1235       }
1236     }
1237 
1238     if (dq.packetCache && !dq.skipCache) {
1239       if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) {
1240 
1241         restoreFlags(dq.getHeader(), dq.origFlags);
1242 
1243         if (!prepareOutgoingResponse(holders, cs, dq, true)) {
1244           return ProcessQueryResult::Drop;
1245         }
1246 
1247         return ProcessQueryResult::SendAnswer;
1248       }
1249       ++g_stats.cacheMisses;
1250     }
1251 
1252     if (!selectedBackend) {
1253       ++g_stats.noPolicy;
1254 
1255       vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toLogString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
1256       if (g_servFailOnNoPolicy) {
1257         dq.getHeader()->rcode = RCode::ServFail;
1258         dq.getHeader()->qr = true;
1259 
1260         fixUpQueryTurnedResponse(dq, dq.origFlags);
1261 
1262         if (!prepareOutgoingResponse(holders, cs, dq, false)) {
1263           return ProcessQueryResult::Drop;
1264         }
1265         // no response-only statistics counter to update.
1266         return ProcessQueryResult::SendAnswer;
1267       }
1268 
1269       return ProcessQueryResult::Drop;
1270     }
1271 
1272     /* save the DNS flags as sent to the backend so we can cache the answer with the right flags later */
1273     dq.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader());
1274 
1275     if (dq.addXPF && selectedBackend->xpfRRCode != 0) {
1276       addXPF(dq, selectedBackend->xpfRRCode);
1277     }
1278 
1279     selectedBackend->incQueriesCount();
1280     return ProcessQueryResult::PassToBackend;
1281   }
1282   catch (const std::exception& e){
1283     vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.tcp ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what());
1284   }
1285   return ProcessQueryResult::Drop;
1286 }
1287 
processUDPQuery(ClientState & cs,LocalHolders & holders,const struct msghdr * msgh,const ComboAddress & remote,ComboAddress & dest,PacketBuffer & query,struct mmsghdr * responsesVect,unsigned int * queuedResponses,struct iovec * respIOV,cmsgbuf_aligned * respCBuf)1288 static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, PacketBuffer& query, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, cmsgbuf_aligned* respCBuf)
1289 {
1290   assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
1291   uint16_t queryId = 0;
1292   ComboAddress proxiedRemote = remote;
1293   ComboAddress proxiedDestination = dest;
1294 
1295   try {
1296     bool expectProxyProtocol = false;
1297     if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest, expectProxyProtocol)) {
1298       return;
1299     }
1300     /* dest might have been updated, if we managed to harvest the destination address */
1301     proxiedDestination = dest;
1302 
1303     std::vector<ProxyProtocolValue> proxyProtocolValues;
1304     if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, proxiedRemote, proxiedDestination, proxyProtocolValues)) {
1305       return;
1306     }
1307 
1308     /* we need an accurate ("real") value for the response and
1309        to store into the IDS, but not for insertion into the
1310        rings for example */
1311     struct timespec queryRealTime;
1312     gettime(&queryRealTime, true);
1313 
1314     std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
1315     auto dnsCryptResponse = checkDNSCryptQuery(cs, query, dnsCryptQuery, queryRealTime.tv_sec, false);
1316     if (dnsCryptResponse) {
1317       sendUDPResponse(cs.udpFD, query, 0, dest, remote);
1318       return;
1319     }
1320 
1321     {
1322       /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
1323       struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query.data());
1324       queryId = ntohs(dh->id);
1325 
1326       if (!checkQueryHeaders(dh)) {
1327         return;
1328       }
1329 
1330       if (dh->qdcount == 0) {
1331         dh->rcode = RCode::NotImp;
1332         dh->qr = true;
1333         sendUDPResponse(cs.udpFD, query, 0, dest, remote);
1334         return;
1335       }
1336     }
1337 
1338     uint16_t qtype, qclass;
1339     unsigned int qnameWireLength = 0;
1340     DNSName qname(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
1341     DNSQuestion dq(&qname, qtype, qclass, proxiedDestination.sin4.sin_family != 0 ? &proxiedDestination : &cs.local, &proxiedRemote, query, false, &queryRealTime);
1342     dq.dnsCryptQuery = std::move(dnsCryptQuery);
1343     if (!proxyProtocolValues.empty()) {
1344       dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
1345     }
1346     dq.hopRemote = &remote;
1347     dq.hopLocal = &dest;
1348     std::shared_ptr<DownstreamState> ss{nullptr};
1349     auto result = processQuery(dq, cs, holders, ss);
1350 
1351     if (result == ProcessQueryResult::Drop) {
1352       return;
1353     }
1354 
1355     // the buffer might have been invalidated by now (resized)
1356     struct dnsheader* dh = dq.getHeader();
1357     if (result == ProcessQueryResult::SendAnswer) {
1358 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
1359       if (dq.delayMsec == 0 && responsesVect != nullptr) {
1360         queueResponse(cs, query, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf);
1361         (*queuedResponses)++;
1362         return;
1363       }
1364 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
1365       /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */
1366       sendUDPResponse(cs.udpFD, query, dq.delayMsec, dest, remote);
1367       return;
1368     }
1369 
1370     if (result != ProcessQueryResult::PassToBackend || ss == nullptr) {
1371       return;
1372     }
1373 
1374     unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
1375     IDState* ids = &ss->idStates[idOffset];
1376     ids->age = 0;
1377     DOHUnit* du = nullptr;
1378 
1379     /* that means that the state was in use, possibly with an allocated
1380        DOHUnit that we will need to handle, but we can't touch it before
1381        confirming that we now own this state */
1382     if (ids->isInUse()) {
1383       du = ids->du;
1384     }
1385 
1386     /* we atomically replace the value, we now own this state */
1387     if (!ids->markAsUsed()) {
1388       /* the state was not in use.
1389          we reset 'du' because it might have still been in use when we read it. */
1390       du = nullptr;
1391       ++ss->outstanding;
1392     }
1393     else {
1394       /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
1395          to handle it because it's about to be overwritten. */
1396       ids->du = nullptr;
1397       ++ss->reuseds;
1398       ++g_stats.downstreamTimeouts;
1399       handleDOHTimeout(du);
1400     }
1401 
1402     ids->cs = &cs;
1403     ids->origFD = cs.udpFD;
1404     ids->origID = dh->id;
1405     setIDStateFromDNSQuestion(*ids, dq, std::move(qname));
1406 
1407     if (dest.sin4.sin_family != 0) {
1408       ids->origDest = dest;
1409     }
1410     else {
1411       ids->origDest = cs.local;
1412     }
1413 
1414     dh = dq.getHeader();
1415     dh->id = idOffset;
1416 
1417     if (ss->useProxyProtocol) {
1418       addProxyProtocol(dq);
1419     }
1420 
1421     int fd = pickBackendSocketForSending(ss);
1422     ssize_t ret = udpClientSendRequestToBackend(ss, fd, query);
1423 
1424     if(ret < 0) {
1425       ++ss->sendErrors;
1426       ++g_stats.downstreamSendErrors;
1427     }
1428 
1429     vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toLogString(), QType(ids->qtype).getName(), proxiedRemote.toStringWithPort(), ss->getName());
1430   }
1431   catch(const std::exception& e){
1432     vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", proxiedRemote.toStringWithPort(), queryId, e.what());
1433   }
1434 }
1435 
#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
/* Receive loop for a UDP frontend using recvmmsg()/sendmmsg().
   Each iteration reads up to g_udpVectorSize datagrams in one syscall, processes each of them
   with processUDPQuery(), then sends the immediate answers (rule, dynamic block or cache) in a
   single batch. Loops forever; exceptions propagate to the caller. */
static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holders)
{
  struct MMReceiver
  {
    PacketBuffer packet;
    ComboAddress remote;
    ComboAddress dest;
    struct iovec iov;
    /* used by HarvestDestinationAddress */
    cmsgbuf_aligned cbuf;
  };
  const size_t vectSize = g_udpVectorSize;

  auto recvData = std::unique_ptr<MMReceiver[]>(new MMReceiver[vectSize]);
  auto msgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);
  auto outMsgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);

  /* the actual buffer is larger because:
     - we may have to add EDNS and/or ECS
     - we use it for self-generated responses (from rule or cache)
     but we only accept incoming payloads up to that size
  */
  const size_t initialBufferSize = getInitialUDPPacketBufferSize();
  const size_t maxIncomingPacketSize = getMaximumIncomingPacketSize(*cs);

  /* initialize the structures needed to receive our messages */
  for (size_t idx = 0; idx < vectSize; idx++) {
    recvData[idx].remote.sin4.sin_family = cs->local.sin4.sin_family;
    recvData[idx].packet.resize(initialBufferSize);
    fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, &recvData[idx].cbuf, sizeof(recvData[idx].cbuf), reinterpret_cast<char*>(&recvData[idx].packet.at(0)), maxIncomingPacketSize, &recvData[idx].remote);
  }

  /* go now */
  for(;;) {

    /* Reset the per-message receive state before every call:
       - the IO vector is also used to send the vector of responses, to avoid having to copy
         the data around, so it has to point back to the receive buffer;
       - its length must stay capped to the maximum accepted incoming size (not the larger buffer
         size), so that bigger datagrams are flagged with MSG_TRUNC and rejected;
       - the kernel overwrites msg_controllen with the size of the ancillary data it stored, so
         the full control buffer size has to be restored. */
    for (size_t idx = 0; idx < vectSize; idx++) {
      auto& receiver = recvData[idx];
      receiver.packet.resize(initialBufferSize);
      receiver.iov.iov_base = &receiver.packet.at(0);
      receiver.iov.iov_len = maxIncomingPacketSize < receiver.packet.size() ? maxIncomingPacketSize : receiver.packet.size();
      msgVec[idx].msg_hdr.msg_controllen = sizeof(receiver.cbuf);
    }

    /* block until we have at least one message ready, but return
       as many as possible to save the syscall costs */
    int msgsGot = recvmmsg(cs->udpFD, msgVec.get(), vectSize, MSG_WAITFORONE | MSG_TRUNC, nullptr);

    if (msgsGot <= 0) {
      vinfolog("Getting UDP messages via recvmmsg() failed with: %s", stringerror());
      continue;
    }

    unsigned int msgsToSend = 0;

    /* process the received messages */
    for (int msgIdx = 0; msgIdx < msgsGot; msgIdx++) {
      const struct msghdr* msgh = &msgVec[msgIdx].msg_hdr;
      unsigned int got = msgVec[msgIdx].msg_len;
      const ComboAddress& remote = recvData[msgIdx].remote;

      if (static_cast<size_t>(got) < sizeof(struct dnsheader)) {
        ++g_stats.nonCompliantQueries;
        continue;
      }

      recvData[msgIdx].packet.resize(got);
      processUDPQuery(*cs, holders, msgh, remote, recvData[msgIdx].dest, recvData[msgIdx].packet, outMsgVec.get(), &msgsToSend, &recvData[msgIdx].iov, &recvData[msgIdx].cbuf);
    }

    /* immediate (not delayed or sent to a backend) responses (mostly from a rule, dynamic block
       or the cache) can be sent in batch too */

    if (msgsToSend > 0 && msgsToSend <= static_cast<unsigned int>(msgsGot)) {
      int sent = sendmmsg(cs->udpFD, outMsgVec.get(), msgsToSend, 0);

      if (sent < 0 || static_cast<unsigned int>(sent) != msgsToSend) {
        vinfolog("Error sending responses with sendmmsg() (%d on %u): %s", sent, msgsToSend, stringerror());
      }
    }

  }
}
#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
1520 
1521 // listens to incoming queries, sends out to downstream servers, noting the intended return path
udpClientThread(ClientState * cs)1522 static void udpClientThread(ClientState* cs)
1523 {
1524   try {
1525     setThreadName("dnsdist/udpClie");
1526     LocalHolders holders;
1527 
1528 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
1529     if (g_udpVectorSize > 1) {
1530       MultipleMessagesUDPClientThread(cs, holders);
1531     }
1532     else
1533 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
1534     {
1535       /* the actual buffer is larger because:
1536          - we may have to add EDNS and/or ECS
1537          - we use it for self-generated responses (from rule or cache)
1538          but we only accept incoming payloads up to that size
1539       */
1540       const size_t initialBufferSize = getInitialUDPPacketBufferSize();
1541       const size_t maxIncomingPacketSize = getMaximumIncomingPacketSize(*cs);
1542       PacketBuffer packet(initialBufferSize);
1543 
1544       struct msghdr msgh;
1545       struct iovec iov;
1546       /* used by HarvestDestinationAddress */
1547       cmsgbuf_aligned cbuf;
1548 
1549       ComboAddress remote;
1550       ComboAddress dest;
1551       remote.sin4.sin_family = cs->local.sin4.sin_family;
1552       fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), reinterpret_cast<char*>(&packet.at(0)), maxIncomingPacketSize, &remote);
1553 
1554       for(;;) {
1555         packet.resize(initialBufferSize);
1556         iov.iov_base = &packet.at(0);
1557         iov.iov_len = packet.size();
1558 
1559         ssize_t got = recvmsg(cs->udpFD, &msgh, 0);
1560 
1561         if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) {
1562           ++g_stats.nonCompliantQueries;
1563           continue;
1564         }
1565 
1566         packet.resize(static_cast<size_t>(got));
1567 
1568         processUDPQuery(*cs, holders, &msgh, remote, dest, packet, nullptr, nullptr, nullptr, nullptr);
1569       }
1570     }
1571   }
1572   catch(const std::exception &e)
1573   {
1574     errlog("UDP client thread died because of exception: %s", e.what());
1575   }
1576   catch(const PDNSException &e)
1577   {
1578     errlog("UDP client thread died because of PowerDNS exception: %s", e.reason);
1579   }
1580   catch(...)
1581   {
1582     errlog("UDP client thread died because of an exception: %s", "unknown");
1583   }
1584 }
1585 
1586 
getRandomDNSID()1587 uint16_t getRandomDNSID()
1588 {
1589 #ifdef HAVE_LIBSODIUM
1590   return randombytes_uniform(65536);
1591 #else
1592   return (random() % 65536);
1593 #endif
1594 }
1595 
/* Maximum number of TCP client threads; unset (boost::none) by default.
   NOTE(review): how an unset value is interpreted is decided elsewhere — confirm. */
boost::optional<uint64_t> g_maxTCPClientThreads{boost::none};
/* Number of maintThread() iterations (one per second) between two packet cache cleaning passes. */
pdns::stat16_t g_cacheCleaningDelay{60};
/* Percentage used by maintThread() to compute the 'upTo' argument of purgeExpired():
   getMaxEntries() * (100 - percentage) / 100. */
pdns::stat16_t g_cacheCleaningPercentage{100};
1599 
maintThread()1600 static void maintThread()
1601 {
1602   setThreadName("dnsdist/main");
1603   int interval = 1;
1604   size_t counter = 0;
1605   int32_t secondsToWaitLog = 0;
1606 
1607   for (;;) {
1608     sleep(interval);
1609 
1610     {
1611       std::lock_guard<std::mutex> lock(g_luamutex);
1612       auto f = g_lua.readVariable<boost::optional<std::function<void()> > >("maintenance");
1613       if (f) {
1614         try {
1615           (*f)();
1616           secondsToWaitLog = 0;
1617         }
1618         catch(const std::exception &e) {
1619           if (secondsToWaitLog <= 0) {
1620             infolog("Error during execution of maintenance function: %s", e.what());
1621             secondsToWaitLog = 61;
1622           }
1623           secondsToWaitLog -= interval;
1624         }
1625       }
1626     }
1627 
1628     counter++;
1629     if (counter >= g_cacheCleaningDelay) {
1630       /* keep track, for each cache, of whether we should keep
1631        expired entries */
1632       std::map<std::shared_ptr<DNSDistPacketCache>, bool> caches;
1633 
1634       /* gather all caches actually used by at least one pool, and see
1635          if something prevents us from cleaning the expired entries */
1636       auto localPools = g_pools.getLocal();
1637       for (const auto& entry : *localPools) {
1638         auto& pool = entry.second;
1639 
1640         auto packetCache = pool->packetCache;
1641         if (!packetCache) {
1642           continue;
1643         }
1644 
1645         auto pair = caches.insert({packetCache, false});
1646         auto& iter = pair.first;
1647         /* if we need to keep stale data for this cache (ie, not clear
1648            expired entries when at least one pool using this cache
1649            has all its backends down) */
1650         if (packetCache->keepStaleData() && iter->second == false) {
1651           /* so far all pools had at least one backend up */
1652           if (pool->countServers(true) == 0) {
1653             iter->second = true;
1654           }
1655         }
1656       }
1657 
1658       const time_t now = time(nullptr);
1659       for (auto pair : caches) {
1660         /* shall we keep expired entries ? */
1661         if (pair.second == true) {
1662           continue;
1663         }
1664         auto& packetCache = pair.first;
1665         size_t upTo = (packetCache->getMaxEntries()* (100 - g_cacheCleaningPercentage)) / 100;
1666         packetCache->purgeExpired(upTo, now);
1667       }
1668       counter = 0;
1669     }
1670   }
1671 }
1672 
dynBlockMaintenanceThread()1673 static void dynBlockMaintenanceThread()
1674 {
1675   setThreadName("dnsdist/dynBloc");
1676 
1677   DynBlockMaintenance::run();
1678 }
1679 
secPollThread()1680 static void secPollThread()
1681 {
1682   setThreadName("dnsdist/secpoll");
1683 
1684   for (;;) {
1685     try {
1686       doSecPoll(g_secPollSuffix);
1687     }
1688     catch(...) {
1689     }
1690     sleep(g_secPollInterval);
1691   }
1692 }
1693 
healthChecksThread()1694 static void healthChecksThread()
1695 {
1696   setThreadName("dnsdist/healthC");
1697 
1698   static const int interval = 1;
1699 
1700   for(;;) {
1701     sleep(interval);
1702 
1703     auto mplexer = std::shared_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
1704     auto states = g_dstates.getLocal(); // this points to the actual shared_ptrs!
1705     for(auto& dss : *states) {
1706       if (++dss->lastCheck < dss->checkInterval) {
1707         continue;
1708       }
1709 
1710       dss->lastCheck = 0;
1711 
1712       if (dss->availability == DownstreamState::Availability::Auto) {
1713         if (!queueHealthCheck(mplexer, dss)) {
1714           updateHealthCheckResult(dss, false);
1715         }
1716       }
1717 
1718       auto delta = dss->sw.udiffAndSet()/1000000.0;
1719       dss->queryLoad = 1.0*(dss->queries.load() - dss->prev.queries.load())/delta;
1720       dss->dropRate = 1.0*(dss->reuseds.load() - dss->prev.reuseds.load())/delta;
1721       dss->prev.queries.store(dss->queries.load());
1722       dss->prev.reuseds.store(dss->reuseds.load());
1723 
1724       for (IDState& ids  : dss->idStates) { // timeouts
1725         int64_t usageIndicator = ids.usageIndicator;
1726         if(IDState::isInUse(usageIndicator) && ids.age++ > g_udpTimeout) {
1727           /* We mark the state as unused as soon as possible
1728              to limit the risk of racing with the
1729              responder thread.
1730           */
1731           auto oldDU = ids.du;
1732 
1733           if (!ids.tryMarkUnused(usageIndicator)) {
1734             /* this state has been altered in the meantime,
1735                don't go anywhere near it */
1736             continue;
1737           }
1738           ids.du = nullptr;
1739           handleDOHTimeout(oldDU);
1740           ids.age = 0;
1741           dss->reuseds++;
1742           --dss->outstanding;
1743           ++g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout
1744           vinfolog("Had a downstream timeout from %s (%s) for query for %s|%s from %s",
1745                    dss->remote.toStringWithPort(), dss->getName(),
1746                    ids.qname.toLogString(), QType(ids.qtype).getName(), ids.origRemote.toStringWithPort());
1747 
1748           struct timespec ts;
1749           gettime(&ts);
1750 
1751           struct dnsheader fake;
1752           memset(&fake, 0, sizeof(fake));
1753           fake.id = ids.origID;
1754 
1755           g_rings.insertResponse(ts, ids.origRemote, ids.qname, ids.qtype, std::numeric_limits<unsigned int>::max(), 0, fake, dss->remote);
1756         }
1757       }
1758     }
1759 
1760     handleQueuedHealthChecks(mplexer);
1761   }
1762 }
1763 
bindAny(int af,int sock)1764 static void bindAny(int af, int sock)
1765 {
1766   __attribute__((unused)) int one = 1;
1767 
1768 #ifdef IP_FREEBIND
1769   if (setsockopt(sock, IPPROTO_IP, IP_FREEBIND, &one, sizeof(one)) < 0)
1770     warnlog("Warning: IP_FREEBIND setsockopt failed: %s", stringerror());
1771 #endif
1772 
1773 #ifdef IP_BINDANY
1774   if (af == AF_INET)
1775     if (setsockopt(sock, IPPROTO_IP, IP_BINDANY, &one, sizeof(one)) < 0)
1776       warnlog("Warning: IP_BINDANY setsockopt failed: %s", stringerror());
1777 #endif
1778 #ifdef IPV6_BINDANY
1779   if (af == AF_INET6)
1780     if (setsockopt(sock, IPPROTO_IPV6, IPV6_BINDANY, &one, sizeof(one)) < 0)
1781       warnlog("Warning: IPV6_BINDANY setsockopt failed: %s", stringerror());
1782 #endif
1783 #ifdef SO_BINDANY
1784   if (setsockopt(sock, SOL_SOCKET, SO_BINDANY, &one, sizeof(one)) < 0)
1785     warnlog("Warning: SO_BINDANY setsockopt failed: %s", stringerror());
1786 #endif
1787 }
1788 
dropGroupPrivs(gid_t gid)1789 static void dropGroupPrivs(gid_t gid)
1790 {
1791   if (gid) {
1792     if (setgid(gid) == 0) {
1793       if (setgroups(0, NULL) < 0) {
1794         warnlog("Warning: Unable to drop supplementary gids: %s", stringerror());
1795       }
1796     }
1797     else {
1798       warnlog("Warning: Unable to set group ID to %d: %s", gid, stringerror());
1799     }
1800   }
1801 }
1802 
dropUserPrivs(uid_t uid)1803 static void dropUserPrivs(uid_t uid)
1804 {
1805   if(uid) {
1806     if(setuid(uid) < 0) {
1807       warnlog("Warning: Unable to set user ID to %d: %s", uid, stringerror());
1808     }
1809   }
1810 }
1811 
checkFileDescriptorsLimits(size_t udpBindsCount,size_t tcpBindsCount)1812 static void checkFileDescriptorsLimits(size_t udpBindsCount, size_t tcpBindsCount)
1813 {
1814   /* stdin, stdout, stderr */
1815   size_t requiredFDsCount = 3;
1816   auto backends = g_dstates.getLocal();
1817   /* UDP sockets to backends */
1818   size_t backendUDPSocketsCount = 0;
1819   for (const auto& backend : *backends) {
1820     backendUDPSocketsCount += backend->sockets.size();
1821   }
1822   requiredFDsCount += backendUDPSocketsCount;
1823   /* TCP sockets to backends */
1824   if (g_maxTCPClientThreads) {
1825     requiredFDsCount += (backends->size() * (*g_maxTCPClientThreads));
1826   }
1827   /* listening sockets */
1828   requiredFDsCount += udpBindsCount;
1829   requiredFDsCount += tcpBindsCount;
1830   /* number of TCP connections currently served, assuming 1 connection per worker thread which is of course not right */
1831   if (g_maxTCPClientThreads) {
1832     requiredFDsCount += *g_maxTCPClientThreads;
1833     /* max pipes for communicating between TCP acceptors and client threads */
1834     requiredFDsCount += (*g_maxTCPClientThreads * 2);
1835   }
1836   /* max TCP queued connections */
1837   requiredFDsCount += g_maxTCPQueuedConnections;
1838   /* DelayPipe pipe */
1839   requiredFDsCount += 2;
1840   /* syslog socket */
1841   requiredFDsCount++;
1842   /* webserver main socket */
1843   requiredFDsCount++;
1844   /* console main socket */
1845   requiredFDsCount++;
1846   /* carbon export */
1847   requiredFDsCount++;
1848   /* history file */
1849   requiredFDsCount++;
1850   struct rlimit rl;
1851   getrlimit(RLIMIT_NOFILE, &rl);
1852   if (rl.rlim_cur <= requiredFDsCount) {
1853     warnlog("Warning, this configuration can use more than %d file descriptors, web server and console connections not included, and the current limit is %d.", std::to_string(requiredFDsCount), std::to_string(rl.rlim_cur));
1854 #ifdef HAVE_SYSTEMD
1855     warnlog("You can increase this value by using LimitNOFILE= in the systemd unit file or ulimit.");
1856 #else
1857     warnlog("You can increase this value by using ulimit.");
1858 #endif
1859   }
1860 }
1861 
setUpLocalBind(std::unique_ptr<ClientState> & cs)1862 static void setUpLocalBind(std::unique_ptr<ClientState>& cs)
1863 {
1864   /* skip some warnings if there is an identical UDP context */
1865   bool warn = cs->tcp == false || cs->tlsFrontend != nullptr || cs->dohFrontend != nullptr;
1866   int& fd = cs->tcp == false ? cs->udpFD : cs->tcpFD;
1867   (void) warn;
1868 
1869   fd = SSocket(cs->local.sin4.sin_family, cs->tcp == false ? SOCK_DGRAM : SOCK_STREAM, 0);
1870 
1871   if (cs->tcp) {
1872     SSetsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 1);
1873 #ifdef TCP_DEFER_ACCEPT
1874     SSetsockopt(fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, 1);
1875 #endif
1876     if (cs->fastOpenQueueSize > 0) {
1877 #ifdef TCP_FASTOPEN
1878       SSetsockopt(fd, IPPROTO_TCP, TCP_FASTOPEN, cs->fastOpenQueueSize);
1879 #else
1880       if (warn) {
1881         warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
1882       }
1883 #endif
1884     }
1885   }
1886 
1887   if(cs->local.sin4.sin_family == AF_INET6) {
1888     SSetsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, 1);
1889   }
1890 
1891   bindAny(cs->local.sin4.sin_family, fd);
1892 
1893   if(!cs->tcp && IsAnyAddress(cs->local)) {
1894     int one=1;
1895     setsockopt(fd, IPPROTO_IP, GEN_IP_PKTINFO, &one, sizeof(one));     // linux supports this, so why not - might fail on other systems
1896 #ifdef IPV6_RECVPKTINFO
1897     setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one));
1898 #endif
1899   }
1900 
1901   if (cs->reuseport) {
1902     if (!setReusePort(fd)) {
1903       if (warn) {
1904         /* no need to warn again if configured but support is not available, we already did for UDP */
1905         warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
1906       }
1907     }
1908   }
1909 
1910   /* Only set this on IPv4 UDP sockets.
1911      Don't set it for DNSCrypt binds. DNSCrypt pads queries for privacy
1912      purposes, so we do receive large, sometimes fragmented datagrams. */
1913   if (!cs->tcp && !cs->dnscryptCtx) {
1914     try {
1915       setSocketIgnorePMTU(cs->udpFD, cs->local.sin4.sin_family);
1916     }
1917     catch(const std::exception& e) {
1918       warnlog("Failed to set IP_MTU_DISCOVER on UDP server socket for local address '%s': %s", cs->local.toStringWithPort(), e.what());
1919     }
1920   }
1921 
1922   const std::string& itf = cs->interface;
1923   if (!itf.empty()) {
1924 #ifdef SO_BINDTODEVICE
1925     int res = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length());
1926     if (res != 0) {
1927       warnlog("Error setting up the interface on local address '%s': %s", cs->local.toStringWithPort(), stringerror());
1928     }
1929 #else
1930     if (warn) {
1931       warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", cs->local.toStringWithPort());
1932     }
1933 #endif
1934   }
1935 
1936 #ifdef HAVE_EBPF
1937   if (g_defaultBPFFilter) {
1938     cs->attachFilter(g_defaultBPFFilter);
1939     vinfolog("Attaching default BPF Filter to %s frontend %s", (!cs->tcp ? "UDP" : "TCP"), cs->local.toStringWithPort());
1940   }
1941 #endif /* HAVE_EBPF */
1942 
1943   if (cs->tlsFrontend != nullptr) {
1944     if (!cs->tlsFrontend->setupTLS()) {
1945       errlog("Error while setting up TLS on local address '%s', exiting", cs->local.toStringWithPort());
1946       _exit(EXIT_FAILURE);
1947     }
1948   }
1949 
1950   if (cs->dohFrontend != nullptr) {
1951     cs->dohFrontend->setup();
1952   }
1953 
1954   SBind(fd, cs->local);
1955 
1956   if (cs->tcp) {
1957     SListen(cs->tcpFD, cs->tcpListenQueueSize);
1958 
1959     if (cs->tlsFrontend != nullptr) {
1960       warnlog("Listening on %s for TLS", cs->local.toStringWithPort());
1961     }
1962     else if (cs->dohFrontend != nullptr) {
1963       warnlog("Listening on %s for DoH", cs->local.toStringWithPort());
1964     }
1965     else if (cs->dnscryptCtx != nullptr) {
1966       warnlog("Listening on %s for DNSCrypt", cs->local.toStringWithPort());
1967     }
1968     else {
1969       warnlog("Listening on %s", cs->local.toStringWithPort());
1970     }
1971   }
1972 
1973   cs->ready = true;
1974 }
1975 
/* Options collected from the command line by main(). */
struct
{
  std::vector<std::string> locals;   // -l/--local: addresses to listen on
  std::vector<std::string> remotes;  // positional arguments (backends) when not in client mode
  bool checkConfig = false;          // --check-config
  bool beClient = false;             // -c/--client
  bool beSupervised = false;         // --supervised
  std::string command;               // -e/--execute
  std::string config;                // -C/--config
  std::string uid;                   // -u/--uid
  std::string gid;                   // -g/--gid
} g_cmdLine;

/* Set to true by main() once the configuration has been loaded and the list
   of frontends finalized, right before the frontends' sockets are bound. */
std::atomic<bool> g_configurationDone{false};
1990 
usage()1991 static void usage()
1992 {
1993   cout<<endl;
1994   cout<<"Syntax: dnsdist [-C,--config file] [-c,--client [IP[:PORT]]]\n";
1995   cout<<"[-e,--execute cmd] [-h,--help] [-l,--local addr]\n";
1996   cout<<"[-v,--verbose] [--check-config] [--version]\n";
1997   cout<<"\n";
1998   cout<<"-a,--acl netmask      Add this netmask to the ACL\n";
1999   cout<<"-C,--config file      Load configuration from 'file'\n";
2000   cout<<"-c,--client           Operate as a client, connect to dnsdist. This reads\n";
2001   cout<<"                      controlSocket from your configuration file, but also\n";
2002   cout<<"                      accepts an IP:PORT argument\n";
2003 #ifdef HAVE_LIBSODIUM
2004   cout<<"-k,--setkey KEY       Use KEY for encrypted communication to dnsdist. This\n";
2005   cout<<"                      is similar to setting setKey in the configuration file.\n";
2006   cout<<"                      NOTE: this will leak this key in your shell's history\n";
2007   cout<<"                      and in the systems running process list.\n";
2008 #endif
2009   cout<<"--check-config        Validate the configuration file and exit. The exit-code\n";
2010   cout<<"                      reflects the validation, 0 is OK, 1 means an error.\n";
2011   cout<<"                      Any errors are printed as well.\n";
2012   cout<<"-e,--execute cmd      Connect to dnsdist and execute 'cmd'\n";
2013   cout<<"-g,--gid gid          Change the process group ID after binding sockets\n";
2014   cout<<"-h,--help             Display this helpful message\n";
2015   cout<<"-l,--local address    Listen on this local address\n";
2016   cout<<"--supervised          Don't open a console, I'm supervised\n";
2017   cout<<"                        (use with e.g. systemd and daemontools)\n";
2018   cout<<"--disable-syslog      Don't log to syslog, only to stdout\n";
2019   cout<<"                        (use with e.g. systemd)\n";
2020   cout<<"-u,--uid uid          Change the process user ID after binding sockets\n";
2021   cout<<"-v,--verbose          Enable verbose mode\n";
2022   cout<<"-V,--version          Show dnsdist version information and exit\n";
2023 }
2024 
main(int argc,char ** argv)2025 int main(int argc, char** argv)
2026 {
2027   try {
2028     size_t udpBindsCount = 0;
2029     size_t tcpBindsCount = 0;
2030     rl_attempted_completion_function = my_completion;
2031     rl_completion_append_character = 0;
2032 
2033     signal(SIGPIPE, SIG_IGN);
2034     signal(SIGCHLD, SIG_IGN);
2035     openlog("dnsdist", LOG_PID|LOG_NDELAY, LOG_DAEMON);
2036 
2037 #ifdef HAVE_LIBSODIUM
2038     if (sodium_init() == -1) {
2039       cerr<<"Unable to initialize crypto library"<<endl;
2040       exit(EXIT_FAILURE);
2041     }
2042     g_hashperturb=randombytes_uniform(0xffffffff);
2043     srandom(randombytes_uniform(0xffffffff));
2044 #else
2045     {
2046       struct timeval tv;
2047       gettimeofday(&tv, 0);
2048       srandom(tv.tv_sec ^ tv.tv_usec ^ getpid());
2049       g_hashperturb=random();
2050     }
2051 
2052 #endif
2053     ComboAddress clientAddress = ComboAddress();
2054     g_cmdLine.config=SYSCONFDIR "/dnsdist.conf";
2055     struct option longopts[]={
2056       {"acl", required_argument, 0, 'a'},
2057       {"check-config", no_argument, 0, 1},
2058       {"client", no_argument, 0, 'c'},
2059       {"config", required_argument, 0, 'C'},
2060       {"disable-syslog", no_argument, 0, 2},
2061       {"execute", required_argument, 0, 'e'},
2062       {"gid", required_argument, 0, 'g'},
2063       {"help", no_argument, 0, 'h'},
2064       {"local", required_argument, 0, 'l'},
2065       {"setkey", required_argument, 0, 'k'},
2066       {"supervised", no_argument, 0, 3},
2067       {"uid", required_argument, 0, 'u'},
2068       {"verbose", no_argument, 0, 'v'},
2069       {"version", no_argument, 0, 'V'},
2070       {0,0,0,0}
2071     };
2072     int longindex=0;
2073     string optstring;
2074     for(;;) {
2075       int c=getopt_long(argc, argv, "a:cC:e:g:hk:l:u:vV", longopts, &longindex);
2076       if(c==-1)
2077         break;
2078       switch(c) {
2079       case 1:
2080         g_cmdLine.checkConfig=true;
2081         break;
2082       case 2:
2083         g_syslog=false;
2084         break;
2085       case 3:
2086         g_cmdLine.beSupervised=true;
2087         break;
2088       case 'C':
2089         g_cmdLine.config=optarg;
2090         break;
2091       case 'c':
2092         g_cmdLine.beClient=true;
2093         break;
2094       case 'e':
2095         g_cmdLine.command=optarg;
2096         break;
2097       case 'g':
2098         g_cmdLine.gid=optarg;
2099         break;
2100       case 'h':
2101         cout<<"dnsdist "<<VERSION<<endl;
2102         usage();
2103         cout<<"\n";
2104         exit(EXIT_SUCCESS);
2105         break;
2106       case 'a':
2107         optstring=optarg;
2108         g_ACL.modify([optstring](NetmaskGroup& nmg) { nmg.addMask(optstring); });
2109         break;
2110       case 'k':
2111 #ifdef HAVE_LIBSODIUM
2112         if (B64Decode(string(optarg), g_consoleKey) < 0) {
2113           cerr<<"Unable to decode key '"<<optarg<<"'."<<endl;
2114           exit(EXIT_FAILURE);
2115         }
2116 #else
2117         cerr<<"dnsdist has been built without libsodium, -k/--setkey is unsupported."<<endl;
2118         exit(EXIT_FAILURE);
2119 #endif
2120         break;
2121       case 'l':
2122         g_cmdLine.locals.push_back(boost::trim_copy(string(optarg)));
2123         break;
2124       case 'u':
2125         g_cmdLine.uid=optarg;
2126         break;
2127       case 'v':
2128         g_verbose=true;
2129         break;
2130       case 'V':
2131 #ifdef LUAJIT_VERSION
2132         cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<" ["<<LUAJIT_VERSION<<"])"<<endl;
2133 #else
2134         cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<")"<<endl;
2135 #endif
2136         cout<<"Enabled features: ";
2137 #ifdef HAVE_CDB
2138         cout<<"cdb ";
2139 #endif
2140 #ifdef HAVE_DNS_OVER_TLS
2141         cout<<"dns-over-tls(";
2142 #ifdef HAVE_GNUTLS
2143         cout<<"gnutls";
2144 #ifdef HAVE_LIBSSL
2145         cout<<" ";
2146 #endif
2147 #endif
2148 #ifdef HAVE_LIBSSL
2149         cout<<"openssl";
2150 #endif
2151         cout<<") ";
2152 #endif
2153 #ifdef HAVE_DNS_OVER_HTTPS
2154         cout<<"dns-over-https(DOH) ";
2155 #endif
2156 #ifdef HAVE_DNSCRYPT
2157         cout<<"dnscrypt ";
2158 #endif
2159 #ifdef HAVE_EBPF
2160         cout<<"ebpf ";
2161 #endif
2162 #ifdef HAVE_FSTRM
2163         cout<<"fstrm ";
2164 #endif
2165 #ifdef HAVE_LIBCRYPTO
2166         cout<<"ipcipher ";
2167 #endif
2168 #ifdef HAVE_LIBSODIUM
2169         cout<<"libsodium ";
2170 #endif
2171 #ifdef HAVE_LMDB
2172         cout<<"lmdb ";
2173 #endif
2174         cout<<"protobuf ";
2175 #ifdef HAVE_RE2
2176         cout<<"re2 ";
2177 #endif
2178 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
2179         cout<<"recvmmsg/sendmmsg ";
2180 #endif
2181 #ifdef HAVE_NET_SNMP
2182         cout<<"snmp ";
2183 #endif
2184 #ifdef HAVE_SYSTEMD
2185         cout<<"systemd";
2186 #endif
2187         cout<<endl;
2188         exit(EXIT_SUCCESS);
2189         break;
2190       case '?':
2191         //getopt_long printed an error message.
2192         usage();
2193         exit(EXIT_FAILURE);
2194         break;
2195       }
2196     }
2197 
2198     argc -= optind;
2199     argv += optind;
2200     (void) argc;
2201 
2202     for(auto p = argv; *p; ++p) {
2203       if(g_cmdLine.beClient) {
2204         clientAddress = ComboAddress(*p, 5199);
2205       } else {
2206         g_cmdLine.remotes.push_back(*p);
2207       }
2208     }
2209 
2210     ServerPolicy leastOutstandingPol{"leastOutstanding", leastOutstanding, false};
2211 
2212     g_policy.setState(leastOutstandingPol);
2213     if(g_cmdLine.beClient || !g_cmdLine.command.empty()) {
2214       setupLua(g_lua, true, false, g_cmdLine.config);
2215       if (clientAddress != ComboAddress())
2216         g_serverControl = clientAddress;
2217       doClient(g_serverControl, g_cmdLine.command);
2218       _exit(EXIT_SUCCESS);
2219     }
2220 
2221     auto acl = g_ACL.getCopy();
2222     if(acl.empty()) {
2223       for(auto& addr : {"127.0.0.0/8", "10.0.0.0/8", "100.64.0.0/10", "169.254.0.0/16", "192.168.0.0/16", "172.16.0.0/12", "::1/128", "fc00::/7", "fe80::/10"})
2224         acl.addMask(addr);
2225       g_ACL.setState(acl);
2226     }
2227 
2228     auto consoleACL = g_consoleACL.getCopy();
2229     for (const auto& mask : { "127.0.0.1/8", "::1/128" }) {
2230       consoleACL.addMask(mask);
2231     }
2232     g_consoleACL.setState(consoleACL);
2233     registerBuiltInWebHandlers();
2234 
2235     if (g_cmdLine.checkConfig) {
2236       setupLua(g_lua, false, true, g_cmdLine.config);
2237       // No exception was thrown
2238       infolog("Configuration '%s' OK!", g_cmdLine.config);
2239       _exit(EXIT_SUCCESS);
2240     }
2241 
2242     auto todo = setupLua(g_lua, false, false, g_cmdLine.config);
2243 
2244     auto localPools = g_pools.getCopy();
2245     {
2246       bool precompute = false;
2247       if (g_policy.getLocal()->getName() == "chashed") {
2248         precompute = true;
2249       } else {
2250         for (const auto& entry: localPools) {
2251           if (entry.second->policy != nullptr && entry.second->policy->getName() == "chashed") {
2252             precompute = true;
2253             break ;
2254           }
2255         }
2256       }
2257       if (precompute) {
2258         vinfolog("Pre-computing hashes for consistent hash load-balancing policy");
2259         // pre compute hashes
2260         auto backends = g_dstates.getLocal();
2261         for (auto& backend: *backends) {
2262           if (backend->weight < 100) {
2263             vinfolog("Warning, the backend '%s' has a very low weight (%d), which will not yield a good distribution of queries with the 'chashed' policy. Please consider raising it to at least '100'.", backend->getName(), backend->weight);
2264           }
2265 
2266           backend->hash();
2267         }
2268       }
2269     }
2270 
2271     if (!g_cmdLine.locals.empty()) {
2272       for (auto it = g_frontends.begin(); it != g_frontends.end(); ) {
2273         /* DoH, DoT and DNSCrypt frontends are separate */
2274         if ((*it)->dohFrontend == nullptr && (*it)->tlsFrontend == nullptr && (*it)->dnscryptCtx == nullptr) {
2275           it = g_frontends.erase(it);
2276         }
2277         else {
2278           ++it;
2279         }
2280       }
2281 
2282       for(const auto& loc : g_cmdLine.locals) {
2283         /* UDP */
2284         g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress(loc, 53), false, false, 0, "", {})));
2285         /* TCP */
2286         g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress(loc, 53), true, false, 0, "", {})));
2287       }
2288     }
2289 
2290     if (g_frontends.empty()) {
2291       /* UDP */
2292       g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress("127.0.0.1", 53), false, false, 0, "", {})));
2293       /* TCP */
2294       g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress("127.0.0.1", 53), true, false, 0, "", {})));
2295     }
2296 
2297     g_configurationDone = true;
2298 
2299     for(auto& frontend : g_frontends) {
2300       setUpLocalBind(frontend);
2301 
2302       if (frontend->tcp == false) {
2303         ++udpBindsCount;
2304       }
2305       else {
2306         ++tcpBindsCount;
2307       }
2308     }
2309 
2310     warnlog("dnsdist %s comes with ABSOLUTELY NO WARRANTY. This is free software, and you are welcome to redistribute it according to the terms of the GPL version 2", VERSION);
2311 
2312     vector<string> vec;
2313     std::string acls;
2314     g_ACL.getLocal()->toStringVector(&vec);
2315     for(const auto& s : vec) {
2316       if (!acls.empty())
2317         acls += ", ";
2318       acls += s;
2319     }
2320     infolog("ACL allowing queries from: %s", acls.c_str());
2321     vec.clear();
2322     acls.clear();
2323     g_consoleACL.getLocal()->toStringVector(&vec);
2324     for (const auto& entry : vec) {
2325       if (!acls.empty()) {
2326         acls += ", ";
2327       }
2328       acls += entry;
2329     }
2330     infolog("Console ACL allowing connections from: %s", acls.c_str());
2331 
2332 #ifdef HAVE_LIBSODIUM
2333     if (g_consoleEnabled && g_consoleKey.empty()) {
2334       warnlog("Warning, the console has been enabled via 'controlSocket()' but no key has been set with 'setKey()' so all connections will fail until a key has been set");
2335     }
2336 #endif
2337 
2338     uid_t newgid=getegid();
2339     gid_t newuid=geteuid();
2340 
2341     if(!g_cmdLine.gid.empty())
2342       newgid = strToGID(g_cmdLine.gid.c_str());
2343 
2344     if(!g_cmdLine.uid.empty())
2345       newuid = strToUID(g_cmdLine.uid.c_str());
2346 
2347     if (getegid() != newgid) {
2348       if (running_in_service_mgr()) {
2349         errlog("--gid/-g set on command-line, but dnsdist was started as a systemd service. Use the 'Group' setting in the systemd unit file to set the group to run as");
2350         _exit(EXIT_FAILURE);
2351       }
2352       dropGroupPrivs(newgid);
2353     }
2354 
2355     if (geteuid() != newuid) {
2356       if (running_in_service_mgr()) {
2357         errlog("--uid/-u set on command-line, but dnsdist was started as a systemd service. Use the 'User' setting in the systemd unit file to set the user to run as");
2358         _exit(EXIT_FAILURE);
2359       }
2360       dropUserPrivs(newuid);
2361     }
2362 
2363     try {
2364       /* we might still have capabilities remaining,
2365          for example if we have been started as root
2366          without --uid or --gid (please don't do that)
2367          or as an unprivileged user with ambient
2368          capabilities like CAP_NET_BIND_SERVICE.
2369       */
2370       dropCapabilities(g_capabilitiesToRetain);
2371     }
2372     catch (const std::exception& e) {
2373       warnlog("%s", e.what());
2374     }
2375 
2376     /* this need to be done _after_ dropping privileges */
2377     g_delay = new DelayPipe<DelayedPacket>();
2378 
2379     if (g_snmpAgent) {
2380       g_snmpAgent->run();
2381     }
2382 
2383     if (!g_maxTCPClientThreads) {
2384       g_maxTCPClientThreads = std::max(tcpBindsCount, static_cast<size_t>(10));
2385     }
2386     else if (*g_maxTCPClientThreads == 0 && tcpBindsCount > 0) {
2387       warnlog("setMaxTCPClientThreads() has been set to 0 while we are accepting TCP connections, raising to 1");
2388       g_maxTCPClientThreads = 1;
2389     }
2390 
2391     g_tcpclientthreads = std::unique_ptr<TCPClientCollection>(new TCPClientCollection(*g_maxTCPClientThreads, g_useTCPSinglePipe));
2392 
2393     for (auto& t : todo) {
2394       t();
2395     }
2396 
2397     localPools = g_pools.getCopy();
2398     /* create the default pool no matter what */
2399     createPoolIfNotExists(localPools, "");
2400     if(g_cmdLine.remotes.size()) {
2401       for(const auto& address : g_cmdLine.remotes) {
2402         auto ret=std::make_shared<DownstreamState>(ComboAddress(address, 53));
2403         addServerToPool(localPools, "", ret);
2404         if (ret->connected && !ret->threadStarted.test_and_set()) {
2405           ret->tid = thread(responderThread, ret);
2406         }
2407         g_dstates.modify([ret](servers_t& servers) { servers.push_back(ret); });
2408       }
2409     }
2410     g_pools.setState(localPools);
2411 
2412     if(g_dstates.getLocal()->empty()) {
2413       errlog("No downstream servers defined: all packets will get dropped");
2414       // you might define them later, but you need to know
2415     }
2416 
2417     checkFileDescriptorsLimits(udpBindsCount, tcpBindsCount);
2418 
2419     auto mplexer = std::shared_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
2420     for(auto& dss : g_dstates.getCopy()) { // it is a copy, but the internal shared_ptrs are the real deal
2421       if (dss->availability == DownstreamState::Availability::Auto) {
2422         if (!queueHealthCheck(mplexer, dss, true)) {
2423           dss->upStatus = false;
2424           warnlog("Marking downstream %s as 'down'", dss->getNameWithAddr());
2425         }
2426       }
2427     }
2428     handleQueuedHealthChecks(mplexer, true);
2429 
2430     /* we need to create the TCP worker threads before the
2431        acceptor ones, otherwise we might crash when processing
2432        the first TCP query */
2433     while (!g_tcpclientthreads->hasReachedMaxThreads()) {
2434       g_tcpclientthreads->addTCPClientThread();
2435     }
2436 
2437     for(auto& cs : g_frontends) {
2438       if (cs->dohFrontend != nullptr) {
2439 #ifdef HAVE_DNS_OVER_HTTPS
2440         std::thread t1(dohThread, cs.get());
2441         if (!cs->cpus.empty()) {
2442           mapThreadToCPUList(t1.native_handle(), cs->cpus);
2443         }
2444         t1.detach();
2445 #endif /* HAVE_DNS_OVER_HTTPS */
2446         continue;
2447       }
2448       if (cs->udpFD >= 0) {
2449         thread t1(udpClientThread, cs.get());
2450         if (!cs->cpus.empty()) {
2451           mapThreadToCPUList(t1.native_handle(), cs->cpus);
2452         }
2453         t1.detach();
2454       }
2455       else if (cs->tcpFD >= 0) {
2456         thread t1(tcpAcceptorThread, cs.get());
2457         if (!cs->cpus.empty()) {
2458           mapThreadToCPUList(t1.native_handle(), cs->cpus);
2459         }
2460         t1.detach();
2461       }
2462     }
2463 
2464     thread carbonthread(carbonDumpThread);
2465     carbonthread.detach();
2466 
2467     thread stattid(maintThread);
2468     stattid.detach();
2469 
2470     thread healththread(healthChecksThread);
2471 
2472     thread dynBlockMaintThread(dynBlockMaintenanceThread);
2473     dynBlockMaintThread.detach();
2474 
2475     if (!g_secPollSuffix.empty()) {
2476       thread secpollthread(secPollThread);
2477       secpollthread.detach();
2478     }
2479 
2480     if(g_cmdLine.beSupervised) {
2481 #ifdef HAVE_SYSTEMD
2482       sd_notify(0, "READY=1");
2483 #endif
2484       healththread.join();
2485     }
2486     else {
2487       healththread.detach();
2488       doConsole();
2489     }
2490     _exit(EXIT_SUCCESS);
2491 
2492   }
2493   catch (const LuaContext::ExecutionErrorException& e) {
2494     try {
2495       errlog("Fatal Lua error: %s", e.what());
2496       std::rethrow_if_nested(e);
2497     } catch(const std::exception& ne) {
2498       errlog("Details: %s", ne.what());
2499     }
2500     catch (const PDNSException &ae)
2501     {
2502       errlog("Fatal pdns error: %s", ae.reason);
2503     }
2504     _exit(EXIT_FAILURE);
2505   }
2506   catch (const std::exception &e)
2507   {
2508     errlog("Fatal error: %s", e.what());
2509     _exit(EXIT_FAILURE);
2510   }
2511   catch (const PDNSException &ae)
2512   {
2513     errlog("Fatal pdns error: %s", ae.reason);
2514     _exit(EXIT_FAILURE);
2515   }
2516 }
2517 
getLatencyCount(const std::string &)2518 uint64_t getLatencyCount(const std::string&)
2519 {
2520     return g_stats.responses + g_stats.selfAnswered + g_stats.cacheHits;
2521 }
2522