1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include "GetAddrInfo.h"
8 
9 #ifdef DNSQUERY_AVAILABLE
10 // There is a bug in windns.h where the type of parameter ppQueryResultsSet for
11 // DnsQuery_A is dependent on UNICODE being set. It should *always* be
12 // PDNS_RECORDA, but if UNICODE is set it is PDNS_RECORDW. To get around this
13 // we make sure that UNICODE is unset.
14 #  undef UNICODE
15 #  include <ws2tcpip.h>
16 #  undef GetAddrInfo
17 #  include <windns.h>
18 #endif  // DNSQUERY_AVAILABLE
19 
20 #include "mozilla/ClearOnShutdown.h"
21 #include "mozilla/net/DNS.h"
22 #include "NativeDNSResolverOverrideParent.h"
23 #include "prnetdb.h"
24 #include "nsIOService.h"
25 #include "nsHostResolver.h"
26 #include "nsError.h"
27 #include "mozilla/net/DNS.h"
28 #include <algorithm>
29 #include "prerror.h"
30 
31 #include "mozilla/Logging.h"
32 #include "mozilla/StaticPrefs_network.h"
33 
34 namespace mozilla::net {
35 
// Process-wide native DNS override service. It is created lazily (and
// registered with ClearOnShutdown) by NativeDNSResolverOverride::GetSingleton()
// and consulted by GetAddrInfo()/FindAddrOverride() before real resolution.
static StaticRefPtr<NativeDNSResolverOverride> gOverrideService;

// Log module "GetAddrInfo". LOG() emits at Debug level, LOG_WARNING() at
// Warning level; both prefix every message with "[DNS]: ".
static LazyLogModule gGetAddrInfoLog("GetAddrInfo");
#define LOG(msg, ...) \
  MOZ_LOG(gGetAddrInfoLog, LogLevel::Debug, ("[DNS]: " msg, ##__VA_ARGS__))
#define LOG_WARNING(msg, ...) \
  MOZ_LOG(gGetAddrInfoLog, LogLevel::Warning, ("[DNS]: " msg, ##__VA_ARGS__))
43 
44 #ifdef DNSQUERY_AVAILABLE
45 
#  define COMPUTER_NAME_BUFFER_SIZE 100
// Local computer names (DNS host name and NetBIOS name), filled in by
// GetAddrInfoInit(). A single-label lookup for either of these names is
// not routed through DnsQuery_A (see _GetAddrInfo_Portable). An empty
// string means the name could not be obtained.
static char sDNSComputerName[COMPUTER_NAME_BUFFER_SIZE];
static char sNETBIOSComputerName[MAX_COMPUTERNAME_LENGTH + 1];

////////////////////////////
// WINDOWS IMPLEMENTATION //
////////////////////////////

// Ensure consistency of PR_* and AF_* constants to allow for legacy usage of
// PR_* constants with this API.
static_assert(PR_AF_INET == AF_INET && PR_AF_INET6 == AF_INET6 &&
                  PR_AF_UNSPEC == AF_UNSPEC,
              "PR_AF_* must match AF_*");
59 
60 // If successful, returns in aResult a TTL value that is smaller or
61 // equal with the one already there. Gets the TTL value by calling
62 // to DnsQuery_A and iterating through the returned
63 // records to find the one with the smallest TTL value.
_CallDnsQuery_A_Windows(const nsACString & aHost,uint16_t aAddressFamily,DWORD aFlags,std::function<void (PDNS_RECORDA)> aCallback)64 static MOZ_ALWAYS_INLINE nsresult _CallDnsQuery_A_Windows(
65     const nsACString& aHost, uint16_t aAddressFamily, DWORD aFlags,
66     std::function<void(PDNS_RECORDA)> aCallback) {
67   NS_ConvertASCIItoUTF16 name(aHost);
68 
69   auto callDnsQuery_A = [&](uint16_t reqFamily) {
70     PDNS_RECORDA dnsData = nullptr;
71     DNS_STATUS status = DnsQuery_A(aHost.BeginReading(), reqFamily, aFlags,
72                                    nullptr, &dnsData, nullptr);
73     if (status == DNS_INFO_NO_RECORDS || status == DNS_ERROR_RCODE_NAME_ERROR ||
74         !dnsData) {
75       LOG("No DNS records found for %s. status=%X. reqFamily = %X\n",
76           aHost.BeginReading(), status, reqFamily);
77       return NS_ERROR_FAILURE;
78     } else if (status != NOERROR) {
79       LOG_WARNING("DnsQuery_A failed with status %X.\n", status);
80       return NS_ERROR_UNEXPECTED;
81     }
82 
83     for (PDNS_RECORDA curRecord = dnsData; curRecord;
84          curRecord = curRecord->pNext) {
85       // Only records in the answer section are important
86       if (curRecord->Flags.S.Section != DnsSectionAnswer) {
87         continue;
88       }
89       if (curRecord->wType != reqFamily) {
90         continue;
91       }
92 
93       aCallback(curRecord);
94     }
95 
96     DnsFree(dnsData, DNS_FREE_TYPE::DnsFreeRecordList);
97     return NS_OK;
98   };
99 
100   if (aAddressFamily == PR_AF_UNSPEC || aAddressFamily == PR_AF_INET) {
101     callDnsQuery_A(DNS_TYPE_A);
102   }
103 
104   if (aAddressFamily == PR_AF_UNSPEC || aAddressFamily == PR_AF_INET6) {
105     callDnsQuery_A(DNS_TYPE_AAAA);
106   }
107   return NS_OK;
108 }
109 
recordTypeMatchesRequest(uint16_t wType,uint16_t aAddressFamily)110 bool recordTypeMatchesRequest(uint16_t wType, uint16_t aAddressFamily) {
111   if (aAddressFamily == PR_AF_UNSPEC) {
112     return wType == DNS_TYPE_A || wType == DNS_TYPE_AAAA;
113   }
114   if (aAddressFamily == PR_AF_INET) {
115     return wType == DNS_TYPE_A;
116   }
117   if (aAddressFamily == PR_AF_INET6) {
118     return wType == DNS_TYPE_AAAA;
119   }
120   return false;
121 }
122 
_GetTTLData_Windows(const nsACString & aHost,uint32_t * aResult,uint16_t aAddressFamily)123 static MOZ_ALWAYS_INLINE nsresult _GetTTLData_Windows(const nsACString& aHost,
124                                                       uint32_t* aResult,
125                                                       uint16_t aAddressFamily) {
126   MOZ_ASSERT(!aHost.IsEmpty());
127   MOZ_ASSERT(aResult);
128   if (aAddressFamily != PR_AF_UNSPEC && aAddressFamily != PR_AF_INET &&
129       aAddressFamily != PR_AF_INET6) {
130     return NS_ERROR_UNEXPECTED;
131   }
132 
133   // In order to avoid using ANY records which are not always implemented as a
134   // "Gimme what you have" request in hostname resolvers, we should send A
135   // and/or AAAA requests, based on the address family requested.
136   const DWORD ttlFlags =
137       (DNS_QUERY_STANDARD | DNS_QUERY_NO_NETBT | DNS_QUERY_NO_HOSTS_FILE |
138        DNS_QUERY_NO_MULTICAST | DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE |
139        DNS_QUERY_DONT_RESET_TTL_VALUES);
140   unsigned int ttl = (unsigned int)-1;
141   _CallDnsQuery_A_Windows(
142       aHost, aAddressFamily, ttlFlags,
143       [&ttl, &aHost, aAddressFamily](PDNS_RECORDA curRecord) {
144         if (recordTypeMatchesRequest(curRecord->wType, aAddressFamily)) {
145           ttl = std::min<unsigned int>(ttl, curRecord->dwTtl);
146         } else {
147           LOG("Received unexpected record type %u in response for %s.\n",
148               curRecord->wType, aHost.BeginReading());
149         }
150       });
151 
152   if (ttl == (unsigned int)-1) {
153     LOG("No useable TTL found.");
154     return NS_ERROR_FAILURE;
155   }
156 
157   *aResult = ttl;
158   return NS_OK;
159 }
160 
161 static MOZ_ALWAYS_INLINE nsresult
_DNSQuery_A_SingleLabel(const nsACString & aCanonHost,uint16_t aAddressFamily,uint16_t aFlags,AddrInfo ** aAddrInfo)162 _DNSQuery_A_SingleLabel(const nsACString& aCanonHost, uint16_t aAddressFamily,
163                         uint16_t aFlags, AddrInfo** aAddrInfo) {
164   bool setCanonName = aFlags & nsHostResolver::RES_CANON_NAME;
165   nsAutoCString canonName;
166   const DWORD flags = (DNS_QUERY_STANDARD | DNS_QUERY_NO_MULTICAST |
167                        DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE);
168   nsTArray<NetAddr> addresses;
169 
170   _CallDnsQuery_A_Windows(
171       aCanonHost, aAddressFamily, flags, [&](PDNS_RECORDA curRecord) {
172         MOZ_DIAGNOSTIC_ASSERT(curRecord->wType == DNS_TYPE_A ||
173                               curRecord->wType == DNS_TYPE_AAAA);
174         if (setCanonName) {
175           canonName.Assign(curRecord->pName);
176         }
177         NetAddr addr{};
178         addr.inet.family = AF_INET;
179         addr.inet.ip = curRecord->Data.A.IpAddress;
180         addresses.AppendElement(addr);
181       });
182 
183   LOG("Query for: %s has %u results", aCanonHost.BeginReading(),
184       addresses.Length());
185   if (addresses.IsEmpty()) {
186     return NS_ERROR_UNKNOWN_HOST;
187   }
188   RefPtr<AddrInfo> ai(new AddrInfo(
189       aCanonHost, canonName, DNSResolverType::Native, 0, std::move(addresses)));
190   ai.forget(aAddrInfo);
191 
192   return NS_OK;
193 }
194 
195 #endif
196 
197 ////////////////////////////////////
198 // PORTABLE RUNTIME IMPLEMENTATION//
199 ////////////////////////////////////
200 
201 static MOZ_ALWAYS_INLINE nsresult
_GetAddrInfo_Portable(const nsACString & aCanonHost,uint16_t aAddressFamily,uint16_t aFlags,AddrInfo ** aAddrInfo)202 _GetAddrInfo_Portable(const nsACString& aCanonHost, uint16_t aAddressFamily,
203                       uint16_t aFlags, AddrInfo** aAddrInfo) {
204   MOZ_ASSERT(!aCanonHost.IsEmpty());
205   MOZ_ASSERT(aAddrInfo);
206 
207   // We accept the same aFlags that nsHostResolver::ResolveHost accepts, but we
208   // need to translate the aFlags into a form that PR_GetAddrInfoByName
209   // accepts.
210   int prFlags = PR_AI_ADDRCONFIG;
211   if (!(aFlags & nsHostResolver::RES_CANON_NAME)) {
212     prFlags |= PR_AI_NOCANONNAME;
213   }
214 
215   // We need to remove IPv4 records manually because PR_GetAddrInfoByName
216   // doesn't support PR_AF_INET6.
217   bool disableIPv4 = aAddressFamily == PR_AF_INET6;
218   if (disableIPv4) {
219     aAddressFamily = PR_AF_UNSPEC;
220   }
221 
222 #if defined(DNSQUERY_AVAILABLE)
223   if (StaticPrefs::network_dns_dns_query_single_label() &&
224       !aCanonHost.Contains('.') && aCanonHost != "localhost"_ns) {
225     // For some reason we can't use DnsQuery_A to get the computer's IP.
226     if (!aCanonHost.Equals(nsDependentCString(sDNSComputerName),
227                            nsCaseInsensitiveCStringComparator) &&
228         !aCanonHost.Equals(nsDependentCString(sNETBIOSComputerName),
229                            nsCaseInsensitiveCStringComparator)) {
230       // This is a single label name resolve without a dot.
231       // We use DNSQuery_A for these.
232       LOG("Resolving %s using DnsQuery_A (computername: %s)\n",
233           aCanonHost.BeginReading(), sDNSComputerName);
234       return _DNSQuery_A_SingleLabel(aCanonHost, aAddressFamily, aFlags,
235                                      aAddrInfo);
236     }
237   }
238 #endif
239 
240   LOG("Resolving %s using PR_GetAddrInfoByName", aCanonHost.BeginReading());
241   PRAddrInfo* prai =
242       PR_GetAddrInfoByName(aCanonHost.BeginReading(), aAddressFamily, prFlags);
243 
244   if (!prai) {
245     LOG("PR_GetAddrInfoByName returned null PR_GetError:%d PR_GetOSErrpr:%d",
246         PR_GetError(), PR_GetOSError());
247     return NS_ERROR_UNKNOWN_HOST;
248   }
249 
250   nsAutoCString canonName;
251   if (aFlags & nsHostResolver::RES_CANON_NAME) {
252     canonName.Assign(PR_GetCanonNameFromAddrInfo(prai));
253   }
254 
255   bool filterNameCollision =
256       !(aFlags & nsHostResolver::RES_ALLOW_NAME_COLLISION);
257   RefPtr<AddrInfo> ai(new AddrInfo(aCanonHost, prai, disableIPv4,
258                                    filterNameCollision, canonName));
259   PR_FreeAddrInfo(prai);
260   if (ai->Addresses().IsEmpty()) {
261     LOG("PR_GetAddrInfoByName returned empty address list");
262     return NS_ERROR_UNKNOWN_HOST;
263   }
264 
265   ai.forget(aAddrInfo);
266 
267   LOG("PR_GetAddrInfoByName resolved successfully");
268   return NS_OK;
269 }
270 
271 //////////////////////////////////////
272 // COMMON/PLATFORM INDEPENDENT CODE //
273 //////////////////////////////////////
GetAddrInfoInit()274 nsresult GetAddrInfoInit() {
275   LOG("Initializing GetAddrInfo.\n");
276 
277 #ifdef DNSQUERY_AVAILABLE
278   DWORD namesize = COMPUTER_NAME_BUFFER_SIZE;
279   if (!GetComputerNameExA(ComputerNameDnsHostname, sDNSComputerName,
280                           &namesize)) {
281     sDNSComputerName[0] = 0;
282   }
283   namesize = MAX_COMPUTERNAME_LENGTH + 1;
284   if (!GetComputerNameExA(ComputerNameNetBIOS, sNETBIOSComputerName,
285                           &namesize)) {
286     sNETBIOSComputerName[0] = 0;
287   }
288 #endif
289   return NS_OK;
290 }
291 
GetAddrInfoShutdown()292 nsresult GetAddrInfoShutdown() {
293   LOG("Shutting down GetAddrInfo.\n");
294   return NS_OK;
295 }
296 
FindAddrOverride(const nsACString & aHost,uint16_t aAddressFamily,uint16_t aFlags,AddrInfo ** aAddrInfo)297 bool FindAddrOverride(const nsACString& aHost, uint16_t aAddressFamily,
298                       uint16_t aFlags, AddrInfo** aAddrInfo) {
299   RefPtr<NativeDNSResolverOverride> overrideService = gOverrideService;
300   if (!overrideService) {
301     return false;
302   }
303   AutoReadLock lock(overrideService->mLock);
304   auto overrides = overrideService->mOverrides.Lookup(aHost);
305   if (!overrides) {
306     return false;
307   }
308   nsCString* cname = nullptr;
309   if (aFlags & nsHostResolver::RES_CANON_NAME) {
310     cname = overrideService->mCnames.Lookup(aHost).DataPtrOrNull();
311   }
312 
313   RefPtr<AddrInfo> ai;
314 
315   nsTArray<NetAddr> addresses;
316   for (const auto& ip : *overrides) {
317     if (aAddressFamily != AF_UNSPEC && ip.raw.family != aAddressFamily) {
318       continue;
319     }
320     addresses.AppendElement(ip);
321   }
322 
323   if (!cname) {
324     ai = new AddrInfo(aHost, DNSResolverType::Native, 0, std::move(addresses));
325   } else {
326     ai = new AddrInfo(aHost, *cname, DNSResolverType::Native, 0,
327                       std::move(addresses));
328   }
329 
330   ai.forget(aAddrInfo);
331   return true;
332 }
333 
GetAddrInfo(const nsACString & aHost,uint16_t aAddressFamily,uint16_t aFlags,AddrInfo ** aAddrInfo,bool aGetTtl)334 nsresult GetAddrInfo(const nsACString& aHost, uint16_t aAddressFamily,
335                      uint16_t aFlags, AddrInfo** aAddrInfo, bool aGetTtl) {
336   if (NS_WARN_IF(aHost.IsEmpty()) || NS_WARN_IF(!aAddrInfo)) {
337     return NS_ERROR_NULL_POINTER;
338   }
339   *aAddrInfo = nullptr;
340 
341   if (StaticPrefs::network_dns_disabled()) {
342     return NS_ERROR_UNKNOWN_HOST;
343   }
344 
345 #ifdef DNSQUERY_AVAILABLE
346   // The GetTTLData needs the canonical name to function properly
347   if (aGetTtl) {
348     aFlags |= nsHostResolver::RES_CANON_NAME;
349   }
350 #endif
351 
352   // If there is an override for this host, then we synthetize a result.
353   if (gOverrideService &&
354       FindAddrOverride(aHost, aAddressFamily, aFlags, aAddrInfo)) {
355     LOG("Returning IP address from NativeDNSResolverOverride");
356     return (*aAddrInfo)->Addresses().Length() ? NS_OK : NS_ERROR_UNKNOWN_HOST;
357   }
358 
359   nsAutoCString host;
360   if (StaticPrefs::network_dns_copy_string_before_call()) {
361     host = Substring(aHost.BeginReading(), aHost.Length());
362     MOZ_ASSERT(aHost.BeginReading() != host.BeginReading());
363   } else {
364     host = aHost;
365   }
366 
367   if (gNativeIsLocalhost) {
368     // pretend we use the given host but use IPv4 localhost instead!
369     host = "localhost"_ns;
370     aAddressFamily = PR_AF_INET;
371   }
372 
373   RefPtr<AddrInfo> info;
374   nsresult rv =
375       _GetAddrInfo_Portable(host, aAddressFamily, aFlags, getter_AddRefs(info));
376 
377 #ifdef DNSQUERY_AVAILABLE
378   if (aGetTtl && NS_SUCCEEDED(rv)) {
379     // Figure out the canonical name, or if that fails, just use the host name
380     // we have.
381     nsAutoCString name;
382     if (info && !info->CanonicalHostname().IsEmpty()) {
383       name = info->CanonicalHostname();
384     } else {
385       name = host;
386     }
387 
388     LOG("Getting TTL for %s (cname = %s).", host.get(), name.get());
389     uint32_t ttl = 0;
390     nsresult ttlRv = _GetTTLData_Windows(name, &ttl, aAddressFamily);
391     if (NS_SUCCEEDED(ttlRv)) {
392       auto builder = info->Build();
393       builder.SetTTL(ttl);
394       info = builder.Finish();
395       LOG("Got TTL %u for %s (name = %s).", ttl, host.get(), name.get());
396     } else {
397       LOG_WARNING("Could not get TTL for %s (cname = %s).", host.get(),
398                   name.get());
399     }
400   }
401 #endif
402 
403   info.forget(aAddrInfo);
404   return rv;
405 }
406 
407 // static
408 already_AddRefed<nsINativeDNSResolverOverride>
GetSingleton()409 NativeDNSResolverOverride::GetSingleton() {
410   if (nsIOService::UseSocketProcess() && XRE_IsParentProcess()) {
411     return NativeDNSResolverOverrideParent::GetSingleton();
412   }
413 
414   if (gOverrideService) {
415     return do_AddRef(gOverrideService);
416   }
417 
418   gOverrideService = new NativeDNSResolverOverride();
419   ClearOnShutdown(&gOverrideService);
420   return do_AddRef(gOverrideService);
421 }
422 
// nsISupports implementation (refcounting + QueryInterface) exposing the
// nsINativeDNSResolverOverride interface.
NS_IMPL_ISUPPORTS(NativeDNSResolverOverride, nsINativeDNSResolverOverride)
424 
425 NS_IMETHODIMP NativeDNSResolverOverride::AddIPOverride(
426     const nsACString& aHost, const nsACString& aIPLiteral) {
427   NetAddr tempAddr;
428 
429   if (aIPLiteral.Equals("N/A"_ns)) {
430     AutoWriteLock lock(mLock);
431     auto& overrides = mOverrides.LookupOrInsert(aHost);
432     overrides.Clear();
433     return NS_OK;
434   }
435 
436   if (NS_FAILED(tempAddr.InitFromString(aIPLiteral))) {
437     return NS_ERROR_UNEXPECTED;
438   }
439 
440   AutoWriteLock lock(mLock);
441   auto& overrides = mOverrides.LookupOrInsert(aHost);
442   overrides.AppendElement(tempAddr);
443 
444   return NS_OK;
445 }
446 
SetCnameOverride(const nsACString & aHost,const nsACString & aCNAME)447 NS_IMETHODIMP NativeDNSResolverOverride::SetCnameOverride(
448     const nsACString& aHost, const nsACString& aCNAME) {
449   if (aCNAME.IsEmpty()) {
450     return NS_ERROR_UNEXPECTED;
451   }
452 
453   AutoWriteLock lock(mLock);
454   mCnames.InsertOrUpdate(aHost, nsCString(aCNAME));
455 
456   return NS_OK;
457 }
458 
ClearHostOverride(const nsACString & aHost)459 NS_IMETHODIMP NativeDNSResolverOverride::ClearHostOverride(
460     const nsACString& aHost) {
461   AutoWriteLock lock(mLock);
462   mCnames.Remove(aHost);
463   auto overrides = mOverrides.Extract(aHost);
464   if (!overrides) {
465     return NS_OK;
466   }
467 
468   overrides->Clear();
469   return NS_OK;
470 }
471 
ClearOverrides()472 NS_IMETHODIMP NativeDNSResolverOverride::ClearOverrides() {
473   AutoWriteLock lock(mLock);
474   mOverrides.Clear();
475   mCnames.Clear();
476   return NS_OK;
477 }
478 
479 }  // namespace mozilla::net
480