1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/resolve_context.h"
6
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <string>
#include <utility>
12
13 #include "base/logging.h"
14 #include "base/metrics/bucket_ranges.h"
15 #include "base/metrics/histogram.h"
16 #include "base/metrics/histogram_base.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/metrics/sample_vector.h"
19 #include "base/no_destructor.h"
20 #include "base/numerics/safe_conversions.h"
21 #include "base/strings/stringprintf.h"
22 #include "net/base/network_change_notifier.h"
23 #include "net/dns/dns_server_iterator.h"
24 #include "net/dns/dns_session.h"
25 #include "net/dns/dns_util.h"
26 #include "net/dns/host_cache.h"
27
28 namespace net {
29
30 namespace {
31
32 // Set min timeout, in case we are talking to a local DNS proxy.
33 const base::TimeDelta kMinTimeout = base::TimeDelta::FromMilliseconds(10);
34
35 // Default maximum timeout between queries, even with exponential backoff.
36 // (Can be overridden by field trial.)
37 const base::TimeDelta kDefaultMaxTimeout = base::TimeDelta::FromSeconds(5);
38
39 // Maximum RTT that will fit in the RTT histograms.
40 const base::TimeDelta kRttMax = base::TimeDelta::FromSeconds(30);
41 // Number of buckets in the histogram of observed RTTs.
42 const size_t kRttBucketCount = 350;
43 // Target percentile in the RTT histogram used for retransmission timeout.
44 const int kRttPercentile = 99;
45 // Number of samples to seed the histogram with.
46 const base::HistogramBase::Count kNumSeeds = 2;
47
GetDefaultTimeout(const DnsConfig & config)48 base::TimeDelta GetDefaultTimeout(const DnsConfig& config) {
49 NetworkChangeNotifier::ConnectionType type =
50 NetworkChangeNotifier::GetConnectionType();
51 return GetTimeDeltaForConnectionTypeFromFieldTrialOrDefault(
52 "AsyncDnsInitialTimeoutMsByConnectionType", config.timeout, type);
53 }
54
GetMaxTimeout()55 base::TimeDelta GetMaxTimeout() {
56 NetworkChangeNotifier::ConnectionType type =
57 NetworkChangeNotifier::GetConnectionType();
58 return GetTimeDeltaForConnectionTypeFromFieldTrialOrDefault(
59 "AsyncDnsMaxTimeoutMsByConnectionType", kDefaultMaxTimeout, type);
60 }
61
62 class RttBuckets : public base::BucketRanges {
63 public:
RttBuckets()64 RttBuckets() : base::BucketRanges(kRttBucketCount + 1) {
65 base::Histogram::InitializeBucketRanges(
66 1,
67 base::checked_cast<base::HistogramBase::Sample>(
68 kRttMax.InMilliseconds()),
69 this);
70 }
71 };
72
GetRttBuckets()73 static RttBuckets* GetRttBuckets() {
74 static base::NoDestructor<RttBuckets> buckets;
75 return buckets.get();
76 }
77
GetRttHistogram(base::TimeDelta rtt_estimate)78 static std::unique_ptr<base::SampleVector> GetRttHistogram(
79 base::TimeDelta rtt_estimate) {
80 std::unique_ptr<base::SampleVector> histogram =
81 std::make_unique<base::SampleVector>(GetRttBuckets());
82 // Seed histogram with 2 samples at |rtt_estimate| timeout.
83 histogram->Accumulate(base::checked_cast<base::HistogramBase::Sample>(
84 rtt_estimate.InMilliseconds()),
85 kNumSeeds);
86 return histogram;
87 }
88
89 } // namespace
90
ServerStats(std::unique_ptr<base::SampleVector> buckets)91 ResolveContext::ServerStats::ServerStats(
92 std::unique_ptr<base::SampleVector> buckets)
93 : last_failure_count(0), rtt_histogram(std::move(buckets)) {}
94
95 ResolveContext::ServerStats::ServerStats(ServerStats&&) = default;
96
97 ResolveContext::ServerStats::~ServerStats() = default;
98
ResolveContext(URLRequestContext * url_request_context,bool enable_caching)99 ResolveContext::ResolveContext(URLRequestContext* url_request_context,
100 bool enable_caching)
101 : url_request_context_(url_request_context),
102 host_cache_(enable_caching ? HostCache::CreateDefaultCache() : nullptr) {
103 max_timeout_ = GetMaxTimeout();
104 }
105
106 ResolveContext::~ResolveContext() = default;
107
GetDohIterator(const DnsConfig & config,const DnsConfig::SecureDnsMode & mode,const DnsSession * session)108 std::unique_ptr<DnsServerIterator> ResolveContext::GetDohIterator(
109 const DnsConfig& config,
110 const DnsConfig::SecureDnsMode& mode,
111 const DnsSession* session) {
112 // Make the iterator even if the session differs. The first call to the member
113 // functions will catch the out of date session.
114
115 std::unique_ptr<DnsServerIterator> itr(new DohDnsServerIterator(
116 doh_server_stats_.size(), FirstServerIndex(true, session),
117 config.doh_attempts, config.attempts, mode, this, session));
118 return itr;
119 }
120
GetClassicDnsIterator(const DnsConfig & config,const DnsSession * session)121 std::unique_ptr<DnsServerIterator> ResolveContext::GetClassicDnsIterator(
122 const DnsConfig& config,
123 const DnsSession* session) {
124 // Make the iterator even if the session differs. The first call to the member
125 // functions will catch the out of date session.
126
127 std::unique_ptr<DnsServerIterator> itr(new ClassicDnsServerIterator(
128 config.nameservers.size(), FirstServerIndex(false, session),
129 config.attempts, config.attempts, this, session));
130 return itr;
131 }
132
GetDohServerAvailability(size_t doh_server_index,const DnsSession * session) const133 bool ResolveContext::GetDohServerAvailability(size_t doh_server_index,
134 const DnsSession* session) const {
135 if (!IsCurrentSession(session))
136 return false;
137
138 CHECK_LT(doh_server_index, doh_server_stats_.size());
139 return ServerStatsToDohAvailability(doh_server_stats_[doh_server_index]);
140 }
141
NumAvailableDohServers(const DnsSession * session) const142 size_t ResolveContext::NumAvailableDohServers(const DnsSession* session) const {
143 if (!IsCurrentSession(session))
144 return 0;
145
146 return std::count_if(doh_server_stats_.cbegin(), doh_server_stats_.cend(),
147 &ServerStatsToDohAvailability);
148 }
149
RecordServerFailure(size_t server_index,bool is_doh_server,const DnsSession * session)150 void ResolveContext::RecordServerFailure(size_t server_index,
151 bool is_doh_server,
152 const DnsSession* session) {
153 if (!IsCurrentSession(session))
154 return;
155
156 size_t num_available_doh_servers_before = NumAvailableDohServers(session);
157
158 ServerStats* stats = GetServerStats(server_index, is_doh_server);
159 ++(stats->last_failure_count);
160 stats->last_failure = base::TimeTicks::Now();
161
162 size_t num_available_doh_servers_now = NumAvailableDohServers(session);
163 if (num_available_doh_servers_now < num_available_doh_servers_before) {
164 NotifyDohStatusObserversOfUnavailable(false /* network_change */);
165
166 // TODO(crbug.com/1022059): Consider figuring out some way to only for the
167 // first context enabling DoH or the last context disabling DoH.
168 if (num_available_doh_servers_now == 0)
169 NetworkChangeNotifier::TriggerNonSystemDnsChange();
170 }
171 }
172
RecordServerSuccess(size_t server_index,bool is_doh_server,const DnsSession * session)173 void ResolveContext::RecordServerSuccess(size_t server_index,
174 bool is_doh_server,
175 const DnsSession* session) {
176 if (!IsCurrentSession(session))
177 return;
178
179 bool doh_available_before = NumAvailableDohServers(session) > 0;
180
181 ServerStats* stats = GetServerStats(server_index, is_doh_server);
182 stats->last_failure_count = 0;
183 stats->current_connection_success = true;
184 stats->last_failure = base::TimeTicks();
185 stats->last_success = base::TimeTicks::Now();
186
187 // TODO(crbug.com/1022059): Consider figuring out some way to only for the
188 // first context enabling DoH or the last context disabling DoH.
189 bool doh_available_now = NumAvailableDohServers(session) > 0;
190 if (doh_available_before != doh_available_now)
191 NetworkChangeNotifier::TriggerNonSystemDnsChange();
192 }
193
RecordRtt(size_t server_index,bool is_doh_server,base::TimeDelta rtt,int rv,const DnsSession * session)194 void ResolveContext::RecordRtt(size_t server_index,
195 bool is_doh_server,
196 base::TimeDelta rtt,
197 int rv,
198 const DnsSession* session) {
199 if (!IsCurrentSession(session))
200 return;
201
202 RecordRttForUma(server_index, is_doh_server, rtt, rv, session);
203
204 ServerStats* stats = GetServerStats(server_index, is_doh_server);
205
206 // RTT values shouldn't be less than 0, but it shouldn't cause a crash if
207 // they are anyway, so clip to 0. See https://crbug.com/753568.
208 if (rtt < base::TimeDelta())
209 rtt = base::TimeDelta();
210
211 // Histogram-based method.
212 stats->rtt_histogram->Accumulate(
213 base::saturated_cast<base::HistogramBase::Sample>(rtt.InMilliseconds()),
214 1);
215 }
216
NextClassicTimeout(size_t classic_server_index,int attempt,const DnsSession * session)217 base::TimeDelta ResolveContext::NextClassicTimeout(size_t classic_server_index,
218 int attempt,
219 const DnsSession* session) {
220 if (!IsCurrentSession(session))
221 return std::min(GetDefaultTimeout(session->config()), max_timeout_);
222
223 return NextTimeoutHelper(
224 GetServerStats(classic_server_index, false /* is _doh_server */),
225 attempt / current_session_->config().nameservers.size());
226 }
227
NextDohTimeout(size_t doh_server_index,const DnsSession * session)228 base::TimeDelta ResolveContext::NextDohTimeout(size_t doh_server_index,
229 const DnsSession* session) {
230 if (!IsCurrentSession(session))
231 return std::min(GetDefaultTimeout(session->config()), max_timeout_);
232
233 return NextTimeoutHelper(
234 GetServerStats(doh_server_index, true /* is _doh_server */),
235 0 /* num_backoffs */);
236 }
237
RegisterDohStatusObserver(DohStatusObserver * observer)238 void ResolveContext::RegisterDohStatusObserver(DohStatusObserver* observer) {
239 DCHECK(observer);
240 doh_status_observers_.AddObserver(observer);
241 }
242
UnregisterDohStatusObserver(const DohStatusObserver * observer)243 void ResolveContext::UnregisterDohStatusObserver(
244 const DohStatusObserver* observer) {
245 DCHECK(observer);
246 doh_status_observers_.RemoveObserver(observer);
247 }
248
InvalidateCachesAndPerSessionData(const DnsSession * new_session,bool network_change)249 void ResolveContext::InvalidateCachesAndPerSessionData(
250 const DnsSession* new_session,
251 bool network_change) {
252 if (host_cache_)
253 host_cache_->Invalidate();
254
255 // DNS config is constant for any given session, so if the current session is
256 // unchanged, any per-session data is safe to keep, even if it's dependent on
257 // a specific config.
258 if (new_session && new_session == current_session_.get())
259 return;
260
261 current_session_.reset();
262 classic_server_stats_.clear();
263 doh_server_stats_.clear();
264 initial_timeout_ = base::TimeDelta();
265 max_timeout_ = GetMaxTimeout();
266
267 if (!new_session) {
268 NotifyDohStatusObserversOfSessionChanged();
269 return;
270 }
271
272 current_session_ = new_session->GetWeakPtr();
273
274 initial_timeout_ = GetDefaultTimeout(current_session_->config());
275
276 for (size_t i = 0; i < new_session->config().nameservers.size(); ++i) {
277 classic_server_stats_.emplace_back(GetRttHistogram(initial_timeout_));
278 }
279 for (size_t i = 0; i < new_session->config().dns_over_https_servers.size();
280 ++i) {
281 doh_server_stats_.emplace_back(GetRttHistogram(initial_timeout_));
282 }
283
284 CHECK_EQ(new_session->config().nameservers.size(),
285 classic_server_stats_.size());
286 CHECK_EQ(new_session->config().dns_over_https_servers.size(),
287 doh_server_stats_.size());
288
289 NotifyDohStatusObserversOfSessionChanged();
290
291 if (!doh_server_stats_.empty())
292 NotifyDohStatusObserversOfUnavailable(network_change);
293 }
294
FirstServerIndex(bool doh_server,const DnsSession * session)295 size_t ResolveContext::FirstServerIndex(bool doh_server,
296 const DnsSession* session) {
297 if (!IsCurrentSession(session))
298 return 0u;
299
300 // DoH first server doesn't rotate, so always return 0u.
301 if (doh_server)
302 return 0u;
303
304 size_t index = classic_server_index_;
305 if (current_session_->config().rotate) {
306 classic_server_index_ = (classic_server_index_ + 1) %
307 current_session_->config().nameservers.size();
308 }
309 return index;
310 }
311
IsCurrentSession(const DnsSession * session) const312 bool ResolveContext::IsCurrentSession(const DnsSession* session) const {
313 CHECK(session);
314 if (session == current_session_.get()) {
315 CHECK_EQ(current_session_->config().nameservers.size(),
316 classic_server_stats_.size());
317 CHECK_EQ(current_session_->config().dns_over_https_servers.size(),
318 doh_server_stats_.size());
319 return true;
320 }
321
322 return false;
323 }
324
GetServerStats(size_t server_index,bool is_doh_server)325 ResolveContext::ServerStats* ResolveContext::GetServerStats(
326 size_t server_index,
327 bool is_doh_server) {
328 if (!is_doh_server) {
329 CHECK_LT(server_index, classic_server_stats_.size());
330 return &classic_server_stats_[server_index];
331 } else {
332 CHECK_LT(server_index, doh_server_stats_.size());
333 return &doh_server_stats_[server_index];
334 }
335 }
336
NextTimeoutHelper(ServerStats * server_stats,int num_backoffs)337 base::TimeDelta ResolveContext::NextTimeoutHelper(ServerStats* server_stats,
338 int num_backoffs) {
339 // Respect initial timeout (from config or field trial) if it exceeds max.
340 if (initial_timeout_ > max_timeout_)
341 return initial_timeout_;
342
343 static_assert(std::numeric_limits<base::HistogramBase::Count>::is_signed,
344 "histogram base count assumed to be signed");
345
346 // Use fixed percentile of observed samples.
347 const base::SampleVector& samples = *server_stats->rtt_histogram;
348
349 base::HistogramBase::Count total = samples.TotalCount();
350 base::HistogramBase::Count remaining_count = kRttPercentile * total / 100;
351 size_t index = 0;
352 while (remaining_count > 0 && index < GetRttBuckets()->size()) {
353 remaining_count -= samples.GetCountAtIndex(index);
354 ++index;
355 }
356
357 base::TimeDelta timeout =
358 base::TimeDelta::FromMilliseconds(GetRttBuckets()->range(index));
359
360 timeout = std::max(timeout, kMinTimeout);
361
362 return std::min(timeout * (1 << num_backoffs), max_timeout_);
363 }
364
RecordRttForUma(size_t server_index,bool is_doh_server,base::TimeDelta rtt,int rv,const DnsSession * session)365 void ResolveContext::RecordRttForUma(size_t server_index,
366 bool is_doh_server,
367 base::TimeDelta rtt,
368 int rv,
369 const DnsSession* session) {
370 DCHECK(IsCurrentSession(session));
371
372 std::string query_type;
373 std::string provider_id;
374 if (is_doh_server) {
375 // Secure queries are validated if the DoH server state is available.
376 if (GetDohServerAvailability(server_index, session))
377 query_type = "SecureValidated";
378 else
379 query_type = "SecureNotValidated";
380 provider_id = GetDohProviderIdForHistogramFromDohConfig(
381 current_session_->config().dns_over_https_servers[server_index]);
382 } else {
383 query_type = "Insecure";
384 provider_id = GetDohProviderIdForHistogramFromNameserver(
385 current_session_->config().nameservers[server_index]);
386 }
387 if (rv == OK || rv == ERR_NAME_NOT_RESOLVED) {
388 base::UmaHistogramMediumTimes(
389 base::StringPrintf("Net.DNS.DnsTransaction.%s.%s.SuccessTime",
390 query_type.c_str(), provider_id.c_str()),
391 rtt);
392 } else {
393 base::UmaHistogramMediumTimes(
394 base::StringPrintf("Net.DNS.DnsTransaction.%s.%s.FailureTime",
395 query_type.c_str(), provider_id.c_str()),
396 rtt);
397 if (is_doh_server) {
398 base::UmaHistogramSparse(
399 base::StringPrintf("Net.DNS.DnsTransaction.%s.%s.FailureError",
400 query_type.c_str(), provider_id.c_str()),
401 std::abs(rv));
402 }
403 }
404 }
405
NotifyDohStatusObserversOfSessionChanged()406 void ResolveContext::NotifyDohStatusObserversOfSessionChanged() {
407 for (auto& observer : doh_status_observers_)
408 observer.OnSessionChanged();
409 }
410
NotifyDohStatusObserversOfUnavailable(bool network_change)411 void ResolveContext::NotifyDohStatusObserversOfUnavailable(
412 bool network_change) {
413 for (auto& observer : doh_status_observers_)
414 observer.OnDohServerUnavailable(network_change);
415 }
416
417 // static
ServerStatsToDohAvailability(const ResolveContext::ServerStats & stats)418 bool ResolveContext::ServerStatsToDohAvailability(
419 const ResolveContext::ServerStats& stats) {
420 return stats.last_failure_count < kAutomaticModeFailureLimit &&
421 stats.current_connection_success;
422 }
423
424 } // namespace net
425