1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/resolve_context.h"
6 
7 #include <algorithm>
8 #include <cstdlib>
9 #include <limits>
10 #include <string>
11 #include <utility>
12 
13 #include "base/logging.h"
14 #include "base/metrics/bucket_ranges.h"
15 #include "base/metrics/histogram.h"
16 #include "base/metrics/histogram_base.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/metrics/sample_vector.h"
19 #include "base/no_destructor.h"
20 #include "base/numerics/safe_conversions.h"
21 #include "base/strings/stringprintf.h"
22 #include "net/base/network_change_notifier.h"
23 #include "net/dns/dns_server_iterator.h"
24 #include "net/dns/dns_session.h"
25 #include "net/dns/dns_util.h"
26 #include "net/dns/host_cache.h"
27 
28 namespace net {
29 
30 namespace {
31 
// Set min timeout, in case we are talking to a local DNS proxy. Every
// histogram-derived timeout is floored at this value.
const base::TimeDelta kMinTimeout = base::TimeDelta::FromMilliseconds(10);

// Default maximum timeout between queries, even with exponential backoff.
// (Can be overridden by field trial; see GetMaxTimeout().)
const base::TimeDelta kDefaultMaxTimeout = base::TimeDelta::FromSeconds(5);

// Maximum RTT that will fit in the RTT histograms (upper bound passed to the
// bucket-range initialization in RttBuckets).
const base::TimeDelta kRttMax = base::TimeDelta::FromSeconds(30);
// Number of buckets in the histogram of observed RTTs.
const size_t kRttBucketCount = 350;
// Target percentile in the RTT histogram used for retransmission timeout. The
// timeout is derived from the bucket that contains this percentile.
const int kRttPercentile = 99;
// Number of samples, placed at the initial timeout, that every new per-server
// RTT histogram is seeded with (see GetRttHistogram()).
const base::HistogramBase::Count kNumSeeds = 2;
47 
GetDefaultTimeout(const DnsConfig & config)48 base::TimeDelta GetDefaultTimeout(const DnsConfig& config) {
49   NetworkChangeNotifier::ConnectionType type =
50       NetworkChangeNotifier::GetConnectionType();
51   return GetTimeDeltaForConnectionTypeFromFieldTrialOrDefault(
52       "AsyncDnsInitialTimeoutMsByConnectionType", config.timeout, type);
53 }
54 
GetMaxTimeout()55 base::TimeDelta GetMaxTimeout() {
56   NetworkChangeNotifier::ConnectionType type =
57       NetworkChangeNotifier::GetConnectionType();
58   return GetTimeDeltaForConnectionTypeFromFieldTrialOrDefault(
59       "AsyncDnsMaxTimeoutMsByConnectionType", kDefaultMaxTimeout, type);
60 }
61 
62 class RttBuckets : public base::BucketRanges {
63  public:
RttBuckets()64   RttBuckets() : base::BucketRanges(kRttBucketCount + 1) {
65     base::Histogram::InitializeBucketRanges(
66         1,
67         base::checked_cast<base::HistogramBase::Sample>(
68             kRttMax.InMilliseconds()),
69         this);
70   }
71 };
72 
GetRttBuckets()73 static RttBuckets* GetRttBuckets() {
74   static base::NoDestructor<RttBuckets> buckets;
75   return buckets.get();
76 }
77 
GetRttHistogram(base::TimeDelta rtt_estimate)78 static std::unique_ptr<base::SampleVector> GetRttHistogram(
79     base::TimeDelta rtt_estimate) {
80   std::unique_ptr<base::SampleVector> histogram =
81       std::make_unique<base::SampleVector>(GetRttBuckets());
82   // Seed histogram with 2 samples at |rtt_estimate| timeout.
83   histogram->Accumulate(base::checked_cast<base::HistogramBase::Sample>(
84                             rtt_estimate.InMilliseconds()),
85                         kNumSeeds);
86   return histogram;
87 }
88 
89 }  // namespace
90 
ServerStats(std::unique_ptr<base::SampleVector> buckets)91 ResolveContext::ServerStats::ServerStats(
92     std::unique_ptr<base::SampleVector> buckets)
93     : last_failure_count(0), rtt_histogram(std::move(buckets)) {}
94 
// Defaulted out of line. The move constructor allows ServerStats, which takes
// ownership of its RTT histogram through a std::unique_ptr, to be stored in
// and relocated by the per-server stats containers (see emplace_back() in
// InvalidateCachesAndPerSessionData()).
ResolveContext::ServerStats::ServerStats(ServerStats&&) = default;

ResolveContext::ServerStats::~ServerStats() = default;
98 
// Creates a context tied to |url_request_context|. A default HostCache is
// created only when |enable_caching| is true; otherwise |host_cache_| is null,
// and InvalidateCachesAndPerSessionData() skips cache invalidation.
ResolveContext::ResolveContext(URLRequestContext* url_request_context,
                               bool enable_caching)
    : url_request_context_(url_request_context),
      host_cache_(enable_caching ? HostCache::CreateDefaultCache() : nullptr) {
  // Read the (field-trial-overridable) maximum timeout up front. It is
  // re-read every time the session changes.
  max_timeout_ = GetMaxTimeout();
}
105 
106 ResolveContext::~ResolveContext() = default;
107 
GetDohIterator(const DnsConfig & config,const DnsConfig::SecureDnsMode & mode,const DnsSession * session)108 std::unique_ptr<DnsServerIterator> ResolveContext::GetDohIterator(
109     const DnsConfig& config,
110     const DnsConfig::SecureDnsMode& mode,
111     const DnsSession* session) {
112   // Make the iterator even if the session differs. The first call to the member
113   // functions will catch the out of date session.
114 
115   std::unique_ptr<DnsServerIterator> itr(new DohDnsServerIterator(
116       doh_server_stats_.size(), FirstServerIndex(true, session),
117       config.doh_attempts, config.attempts, mode, this, session));
118   return itr;
119 }
120 
GetClassicDnsIterator(const DnsConfig & config,const DnsSession * session)121 std::unique_ptr<DnsServerIterator> ResolveContext::GetClassicDnsIterator(
122     const DnsConfig& config,
123     const DnsSession* session) {
124   // Make the iterator even if the session differs. The first call to the member
125   // functions will catch the out of date session.
126 
127   std::unique_ptr<DnsServerIterator> itr(new ClassicDnsServerIterator(
128       config.nameservers.size(), FirstServerIndex(false, session),
129       config.attempts, config.attempts, this, session));
130   return itr;
131 }
132 
GetDohServerAvailability(size_t doh_server_index,const DnsSession * session) const133 bool ResolveContext::GetDohServerAvailability(size_t doh_server_index,
134                                               const DnsSession* session) const {
135   if (!IsCurrentSession(session))
136     return false;
137 
138   CHECK_LT(doh_server_index, doh_server_stats_.size());
139   return ServerStatsToDohAvailability(doh_server_stats_[doh_server_index]);
140 }
141 
NumAvailableDohServers(const DnsSession * session) const142 size_t ResolveContext::NumAvailableDohServers(const DnsSession* session) const {
143   if (!IsCurrentSession(session))
144     return 0;
145 
146   return std::count_if(doh_server_stats_.cbegin(), doh_server_stats_.cend(),
147                        &ServerStatsToDohAvailability);
148 }
149 
RecordServerFailure(size_t server_index,bool is_doh_server,const DnsSession * session)150 void ResolveContext::RecordServerFailure(size_t server_index,
151                                          bool is_doh_server,
152                                          const DnsSession* session) {
153   if (!IsCurrentSession(session))
154     return;
155 
156   size_t num_available_doh_servers_before = NumAvailableDohServers(session);
157 
158   ServerStats* stats = GetServerStats(server_index, is_doh_server);
159   ++(stats->last_failure_count);
160   stats->last_failure = base::TimeTicks::Now();
161 
162   size_t num_available_doh_servers_now = NumAvailableDohServers(session);
163   if (num_available_doh_servers_now < num_available_doh_servers_before) {
164     NotifyDohStatusObserversOfUnavailable(false /* network_change */);
165 
166     // TODO(crbug.com/1022059): Consider figuring out some way to only for the
167     // first context enabling DoH or the last context disabling DoH.
168     if (num_available_doh_servers_now == 0)
169       NetworkChangeNotifier::TriggerNonSystemDnsChange();
170   }
171 }
172 
RecordServerSuccess(size_t server_index,bool is_doh_server,const DnsSession * session)173 void ResolveContext::RecordServerSuccess(size_t server_index,
174                                          bool is_doh_server,
175                                          const DnsSession* session) {
176   if (!IsCurrentSession(session))
177     return;
178 
179   bool doh_available_before = NumAvailableDohServers(session) > 0;
180 
181   ServerStats* stats = GetServerStats(server_index, is_doh_server);
182   stats->last_failure_count = 0;
183   stats->current_connection_success = true;
184   stats->last_failure = base::TimeTicks();
185   stats->last_success = base::TimeTicks::Now();
186 
187   // TODO(crbug.com/1022059): Consider figuring out some way to only for the
188   // first context enabling DoH or the last context disabling DoH.
189   bool doh_available_now = NumAvailableDohServers(session) > 0;
190   if (doh_available_before != doh_available_now)
191     NetworkChangeNotifier::TriggerNonSystemDnsChange();
192 }
193 
RecordRtt(size_t server_index,bool is_doh_server,base::TimeDelta rtt,int rv,const DnsSession * session)194 void ResolveContext::RecordRtt(size_t server_index,
195                                bool is_doh_server,
196                                base::TimeDelta rtt,
197                                int rv,
198                                const DnsSession* session) {
199   if (!IsCurrentSession(session))
200     return;
201 
202   RecordRttForUma(server_index, is_doh_server, rtt, rv, session);
203 
204   ServerStats* stats = GetServerStats(server_index, is_doh_server);
205 
206   // RTT values shouldn't be less than 0, but it shouldn't cause a crash if
207   // they are anyway, so clip to 0. See https://crbug.com/753568.
208   if (rtt < base::TimeDelta())
209     rtt = base::TimeDelta();
210 
211   // Histogram-based method.
212   stats->rtt_histogram->Accumulate(
213       base::saturated_cast<base::HistogramBase::Sample>(rtt.InMilliseconds()),
214       1);
215 }
216 
NextClassicTimeout(size_t classic_server_index,int attempt,const DnsSession * session)217 base::TimeDelta ResolveContext::NextClassicTimeout(size_t classic_server_index,
218                                                    int attempt,
219                                                    const DnsSession* session) {
220   if (!IsCurrentSession(session))
221     return std::min(GetDefaultTimeout(session->config()), max_timeout_);
222 
223   return NextTimeoutHelper(
224       GetServerStats(classic_server_index, false /* is _doh_server */),
225       attempt / current_session_->config().nameservers.size());
226 }
227 
NextDohTimeout(size_t doh_server_index,const DnsSession * session)228 base::TimeDelta ResolveContext::NextDohTimeout(size_t doh_server_index,
229                                                const DnsSession* session) {
230   if (!IsCurrentSession(session))
231     return std::min(GetDefaultTimeout(session->config()), max_timeout_);
232 
233   return NextTimeoutHelper(
234       GetServerStats(doh_server_index, true /* is _doh_server */),
235       0 /* num_backoffs */);
236 }
237 
// Adds |observer| (must be non-null) to the list notified via
// OnSessionChanged() and OnDohServerUnavailable().
void ResolveContext::RegisterDohStatusObserver(DohStatusObserver* observer) {
  DCHECK(observer);
  doh_status_observers_.AddObserver(observer);
}

// Removes |observer| (must be non-null) from the DoH status observer list.
void ResolveContext::UnregisterDohStatusObserver(
    const DohStatusObserver* observer) {
  DCHECK(observer);
  doh_status_observers_.RemoveObserver(observer);
}
248 
InvalidateCachesAndPerSessionData(const DnsSession * new_session,bool network_change)249 void ResolveContext::InvalidateCachesAndPerSessionData(
250     const DnsSession* new_session,
251     bool network_change) {
252   if (host_cache_)
253     host_cache_->Invalidate();
254 
255   // DNS config is constant for any given session, so if the current session is
256   // unchanged, any per-session data is safe to keep, even if it's dependent on
257   // a specific config.
258   if (new_session && new_session == current_session_.get())
259     return;
260 
261   current_session_.reset();
262   classic_server_stats_.clear();
263   doh_server_stats_.clear();
264   initial_timeout_ = base::TimeDelta();
265   max_timeout_ = GetMaxTimeout();
266 
267   if (!new_session) {
268     NotifyDohStatusObserversOfSessionChanged();
269     return;
270   }
271 
272   current_session_ = new_session->GetWeakPtr();
273 
274   initial_timeout_ = GetDefaultTimeout(current_session_->config());
275 
276   for (size_t i = 0; i < new_session->config().nameservers.size(); ++i) {
277     classic_server_stats_.emplace_back(GetRttHistogram(initial_timeout_));
278   }
279   for (size_t i = 0; i < new_session->config().dns_over_https_servers.size();
280        ++i) {
281     doh_server_stats_.emplace_back(GetRttHistogram(initial_timeout_));
282   }
283 
284   CHECK_EQ(new_session->config().nameservers.size(),
285            classic_server_stats_.size());
286   CHECK_EQ(new_session->config().dns_over_https_servers.size(),
287            doh_server_stats_.size());
288 
289   NotifyDohStatusObserversOfSessionChanged();
290 
291   if (!doh_server_stats_.empty())
292     NotifyDohStatusObserversOfUnavailable(network_change);
293 }
294 
FirstServerIndex(bool doh_server,const DnsSession * session)295 size_t ResolveContext::FirstServerIndex(bool doh_server,
296                                         const DnsSession* session) {
297   if (!IsCurrentSession(session))
298     return 0u;
299 
300   // DoH first server doesn't rotate, so always return 0u.
301   if (doh_server)
302     return 0u;
303 
304   size_t index = classic_server_index_;
305   if (current_session_->config().rotate) {
306     classic_server_index_ = (classic_server_index_ + 1) %
307                             current_session_->config().nameservers.size();
308   }
309   return index;
310 }
311 
IsCurrentSession(const DnsSession * session) const312 bool ResolveContext::IsCurrentSession(const DnsSession* session) const {
313   CHECK(session);
314   if (session == current_session_.get()) {
315     CHECK_EQ(current_session_->config().nameservers.size(),
316              classic_server_stats_.size());
317     CHECK_EQ(current_session_->config().dns_over_https_servers.size(),
318              doh_server_stats_.size());
319     return true;
320   }
321 
322   return false;
323 }
324 
GetServerStats(size_t server_index,bool is_doh_server)325 ResolveContext::ServerStats* ResolveContext::GetServerStats(
326     size_t server_index,
327     bool is_doh_server) {
328   if (!is_doh_server) {
329     CHECK_LT(server_index, classic_server_stats_.size());
330     return &classic_server_stats_[server_index];
331   } else {
332     CHECK_LT(server_index, doh_server_stats_.size());
333     return &doh_server_stats_[server_index];
334   }
335 }
336 
NextTimeoutHelper(ServerStats * server_stats,int num_backoffs)337 base::TimeDelta ResolveContext::NextTimeoutHelper(ServerStats* server_stats,
338                                                   int num_backoffs) {
339   // Respect initial timeout (from config or field trial) if it exceeds max.
340   if (initial_timeout_ > max_timeout_)
341     return initial_timeout_;
342 
343   static_assert(std::numeric_limits<base::HistogramBase::Count>::is_signed,
344                 "histogram base count assumed to be signed");
345 
346   // Use fixed percentile of observed samples.
347   const base::SampleVector& samples = *server_stats->rtt_histogram;
348 
349   base::HistogramBase::Count total = samples.TotalCount();
350   base::HistogramBase::Count remaining_count = kRttPercentile * total / 100;
351   size_t index = 0;
352   while (remaining_count > 0 && index < GetRttBuckets()->size()) {
353     remaining_count -= samples.GetCountAtIndex(index);
354     ++index;
355   }
356 
357   base::TimeDelta timeout =
358       base::TimeDelta::FromMilliseconds(GetRttBuckets()->range(index));
359 
360   timeout = std::max(timeout, kMinTimeout);
361 
362   return std::min(timeout * (1 << num_backoffs), max_timeout_);
363 }
364 
RecordRttForUma(size_t server_index,bool is_doh_server,base::TimeDelta rtt,int rv,const DnsSession * session)365 void ResolveContext::RecordRttForUma(size_t server_index,
366                                      bool is_doh_server,
367                                      base::TimeDelta rtt,
368                                      int rv,
369                                      const DnsSession* session) {
370   DCHECK(IsCurrentSession(session));
371 
372   std::string query_type;
373   std::string provider_id;
374   if (is_doh_server) {
375     // Secure queries are validated if the DoH server state is available.
376     if (GetDohServerAvailability(server_index, session))
377       query_type = "SecureValidated";
378     else
379       query_type = "SecureNotValidated";
380     provider_id = GetDohProviderIdForHistogramFromDohConfig(
381         current_session_->config().dns_over_https_servers[server_index]);
382   } else {
383     query_type = "Insecure";
384     provider_id = GetDohProviderIdForHistogramFromNameserver(
385         current_session_->config().nameservers[server_index]);
386   }
387   if (rv == OK || rv == ERR_NAME_NOT_RESOLVED) {
388     base::UmaHistogramMediumTimes(
389         base::StringPrintf("Net.DNS.DnsTransaction.%s.%s.SuccessTime",
390                            query_type.c_str(), provider_id.c_str()),
391         rtt);
392   } else {
393     base::UmaHistogramMediumTimes(
394         base::StringPrintf("Net.DNS.DnsTransaction.%s.%s.FailureTime",
395                            query_type.c_str(), provider_id.c_str()),
396         rtt);
397     if (is_doh_server) {
398       base::UmaHistogramSparse(
399           base::StringPrintf("Net.DNS.DnsTransaction.%s.%s.FailureError",
400                              query_type.c_str(), provider_id.c_str()),
401           std::abs(rv));
402     }
403   }
404 }
405 
NotifyDohStatusObserversOfSessionChanged()406 void ResolveContext::NotifyDohStatusObserversOfSessionChanged() {
407   for (auto& observer : doh_status_observers_)
408     observer.OnSessionChanged();
409 }
410 
NotifyDohStatusObserversOfUnavailable(bool network_change)411 void ResolveContext::NotifyDohStatusObserversOfUnavailable(
412     bool network_change) {
413   for (auto& observer : doh_status_observers_)
414     observer.OnDohServerUnavailable(network_change);
415 }
416 
417 // static
ServerStatsToDohAvailability(const ResolveContext::ServerStats & stats)418 bool ResolveContext::ServerStatsToDohAvailability(
419     const ResolveContext::ServerStats& stats) {
420   return stats.last_failure_count < kAutomaticModeFailureLimit &&
421          stats.current_connection_success;
422 }
423 
424 }  // namespace net
425