1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/mdns/mdns_trackers.h"
6 
7 #include <array>
8 #include <limits>
9 #include <utility>
10 
11 #include "discovery/common/config.h"
12 #include "discovery/mdns/mdns_random.h"
13 #include "discovery/mdns/mdns_record_changed_callback.h"
14 #include "discovery/mdns/mdns_sender.h"
15 #include "util/std_util.h"
16 
17 namespace openscreen {
18 namespace discovery {
19 
20 namespace {
21 
// Query and refresh timing, per RFC 6762 Section 5.2:
// https://tools.ietf.org/html/rfc6762#section-5.2

// Fractions of a record's TTL at which a record tracker acts. The first four
// entries (80%, 85%, 90% and 95%) are refresh attempts; the final entry (100%)
// is the expiration of the record itself rather than a refresh.
constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};

// Each successive query interval must grow by at least this factor.
constexpr int kIntervalIncreaseFactor = 2;

// Lower bound on the interval between the first two queries.
constexpr std::chrono::seconds kMinimumQueryInterval = std::chrono::seconds(1);

// Upper bound the querier may apply to the question refresh interval.
constexpr std::chrono::minutes kMaximumQueryInterval =
    std::chrono::minutes(60);
36 
37 // RFC 6762 Section 10.1
38 // https://tools.ietf.org/html/rfc6762#section-10.1
39 
40 // A goodbye record is a record with TTL of 0.
IsGoodbyeRecord(const MdnsRecord & record)41 bool IsGoodbyeRecord(const MdnsRecord& record) {
42   return record.ttl() == std::chrono::seconds(0);
43 }
44 
IsNegativeResponseForType(const MdnsRecord & record,DnsType dns_type)45 bool IsNegativeResponseForType(const MdnsRecord& record, DnsType dns_type) {
46   if (record.dns_type() != DnsType::kNSEC) {
47     return false;
48   }
49 
50   const auto& nsec_types = absl::get<NsecRecordRdata>(record.rdata()).types();
51   return std::find_if(nsec_types.begin(), nsec_types.end(),
52                       [dns_type](DnsType type) {
53                         return type == dns_type || type == DnsType::kANY;
54                       }) != nsec_types.end();
55 }
56 
// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// When a goodbye record is received, the querier should set the record's TTL
// to 1 second.
constexpr std::chrono::seconds kGoodbyeRecordTtl = std::chrono::seconds(1);
61 
62 }  // namespace
63 
MdnsTracker(MdnsSender * sender,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,TrackerType tracker_type)64 MdnsTracker::MdnsTracker(MdnsSender* sender,
65                          TaskRunner* task_runner,
66                          ClockNowFunctionPtr now_function,
67                          MdnsRandom* random_delay,
68                          TrackerType tracker_type)
69     : sender_(sender),
70       task_runner_(task_runner),
71       now_function_(now_function),
72       send_alarm_(now_function, task_runner),
73       random_delay_(random_delay),
74       tracker_type_(tracker_type) {
75   OSP_DCHECK(task_runner_);
76   OSP_DCHECK(now_function_);
77   OSP_DCHECK(random_delay_);
78   OSP_DCHECK(sender_);
79 }
80 
~MdnsTracker()81 MdnsTracker::~MdnsTracker() {
82   send_alarm_.Cancel();
83 
84   for (const MdnsTracker* node : adjacent_nodes_) {
85     node->RemovedReverseAdjacency(this);
86   }
87 }
88 
AddAdjacentNode(const MdnsTracker * node) const89 bool MdnsTracker::AddAdjacentNode(const MdnsTracker* node) const {
90   OSP_DCHECK(node);
91   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
92 
93   auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
94   if (it != adjacent_nodes_.end()) {
95     return false;
96   }
97 
98   adjacent_nodes_.push_back(node);
99   node->AddReverseAdjacency(this);
100   return true;
101 }
102 
RemoveAdjacentNode(const MdnsTracker * node) const103 bool MdnsTracker::RemoveAdjacentNode(const MdnsTracker* node) const {
104   OSP_DCHECK(node);
105   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
106 
107   auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
108   if (it == adjacent_nodes_.end()) {
109     return false;
110   }
111 
112   adjacent_nodes_.erase(it);
113   node->RemovedReverseAdjacency(this);
114   return true;
115 }
116 
AddReverseAdjacency(const MdnsTracker * node) const117 void MdnsTracker::AddReverseAdjacency(const MdnsTracker* node) const {
118   OSP_DCHECK(std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node) ==
119              adjacent_nodes_.end());
120 
121   adjacent_nodes_.push_back(node);
122 }
123 
RemovedReverseAdjacency(const MdnsTracker * node) const124 void MdnsTracker::RemovedReverseAdjacency(const MdnsTracker* node) const {
125   auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
126   OSP_DCHECK(it != adjacent_nodes_.end());
127 
128   adjacent_nodes_.erase(it);
129 }
130 
MdnsRecordTracker(MdnsRecord record,DnsType dns_type,MdnsSender * sender,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,RecordExpiredCallback record_expired_callback)131 MdnsRecordTracker::MdnsRecordTracker(
132     MdnsRecord record,
133     DnsType dns_type,
134     MdnsSender* sender,
135     TaskRunner* task_runner,
136     ClockNowFunctionPtr now_function,
137     MdnsRandom* random_delay,
138     RecordExpiredCallback record_expired_callback)
139     : MdnsTracker(sender,
140                   task_runner,
141                   now_function,
142                   random_delay,
143                   TrackerType::kRecordTracker),
144       record_(std::move(record)),
145       dns_type_(dns_type),
146       start_time_(now_function_()),
147       record_expired_callback_(std::move(record_expired_callback)) {
148   OSP_DCHECK(record_expired_callback_);
149 
150   // RecordTrackers cannot be created for tracking NSEC types or ANY types.
151   OSP_DCHECK(dns_type_ != DnsType::kNSEC);
152   OSP_DCHECK(dns_type_ != DnsType::kANY);
153 
154   // Validate that, if the provided |record| is an NSEC record, then it provides
155   // a negative response for |dns_type|.
156   OSP_DCHECK(record_.dns_type() != DnsType::kNSEC ||
157              IsNegativeResponseForType(record_, dns_type_));
158 
159   ScheduleFollowUpQuery();
160 }
161 
162 MdnsRecordTracker::~MdnsRecordTracker() = default;
163 
Update(const MdnsRecord & new_record)164 ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
165     const MdnsRecord& new_record) {
166   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
167   const bool has_same_rdata = record_.dns_type() == new_record.dns_type() &&
168                               record_.rdata() == new_record.rdata();
169   const bool new_is_negative_response = new_record.dns_type() == DnsType::kNSEC;
170   const bool current_is_negative_response =
171       record_.dns_type() == DnsType::kNSEC;
172 
173   if ((record_.dns_class() != new_record.dns_class()) ||
174       (record_.name() != new_record.name())) {
175     // The new record has been passed to a wrong tracker.
176     return Error::Code::kParameterInvalid;
177   }
178 
179   // New response record must correspond to the correct type.
180   if ((!new_is_negative_response && new_record.dns_type() != dns_type_) ||
181       (new_is_negative_response &&
182        !IsNegativeResponseForType(new_record, dns_type_))) {
183     // The new record has been passed to a wrong tracker.
184     return Error::Code::kParameterInvalid;
185   }
186 
187   // Goodbye records must have the same RDATA but TTL of 0.
188   // RFC 6762 Section 10.1.
189   // https://tools.ietf.org/html/rfc6762#section-10.1
190   if (!new_is_negative_response && !current_is_negative_response &&
191       IsGoodbyeRecord(new_record) && !has_same_rdata) {
192     // The new record has been passed to a wrong tracker.
193     return Error::Code::kParameterInvalid;
194   }
195 
196   UpdateType result = UpdateType::kGoodbye;
197   if (IsGoodbyeRecord(new_record)) {
198     record_ = MdnsRecord(new_record.name(), new_record.dns_type(),
199                          new_record.dns_class(), new_record.record_type(),
200                          kGoodbyeRecordTtl, new_record.rdata());
201 
202     // Goodbye records do not need to be re-queried, set the attempt count to
203     // the last item, which is 100% of TTL, i.e. record expiration.
204     attempt_count_ = countof(kTtlFractions) - 1;
205   } else {
206     record_ = new_record;
207     attempt_count_ = 0;
208     result = has_same_rdata ? UpdateType::kTTLOnly : UpdateType::kRdata;
209   }
210 
211   start_time_ = now_function_();
212   ScheduleFollowUpQuery();
213 
214   return result;
215 }
216 
AddAssociatedQuery(const MdnsQuestionTracker * question_tracker) const217 bool MdnsRecordTracker::AddAssociatedQuery(
218     const MdnsQuestionTracker* question_tracker) const {
219   return AddAdjacentNode(question_tracker);
220 }
221 
RemoveAssociatedQuery(const MdnsQuestionTracker * question_tracker) const222 bool MdnsRecordTracker::RemoveAssociatedQuery(
223     const MdnsQuestionTracker* question_tracker) const {
224   return RemoveAdjacentNode(question_tracker);
225 }
226 
ExpireSoon()227 void MdnsRecordTracker::ExpireSoon() {
228   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
229 
230   record_ =
231       MdnsRecord(record_.name(), record_.dns_type(), record_.dns_class(),
232                  record_.record_type(), kGoodbyeRecordTtl, record_.rdata());
233 
234   // Set the attempt count to the last item, which is 100% of TTL, i.e. record
235   // expiration, to prevent any re-queries
236   attempt_count_ = countof(kTtlFractions) - 1;
237   start_time_ = now_function_();
238   ScheduleFollowUpQuery();
239 }
240 
ExpireNow()241 void MdnsRecordTracker::ExpireNow() {
242   record_expired_callback_(this, record_);
243 }
244 
IsNearingExpiry() const245 bool MdnsRecordTracker::IsNearingExpiry() const {
246   return (now_function_() - start_time_) > record_.ttl() / 2;
247 }
248 
SendQuery() const249 bool MdnsRecordTracker::SendQuery() const {
250   const Clock::time_point expiration_time = start_time_ + record_.ttl();
251   bool is_expired = (now_function_() >= expiration_time);
252   if (!is_expired) {
253     for (const MdnsTracker* tracker : adjacent_nodes()) {
254       tracker->SendQuery();
255     }
256   } else {
257     record_expired_callback_(this, record_);
258   }
259 
260   return !is_expired;
261 }
262 
ScheduleFollowUpQuery()263 void MdnsRecordTracker::ScheduleFollowUpQuery() {
264   send_alarm_.Schedule(
265       [this] {
266         if (SendQuery()) {
267           ScheduleFollowUpQuery();
268         }
269       },
270       GetNextSendTime());
271 }
272 
GetRecords() const273 std::vector<MdnsRecord> MdnsRecordTracker::GetRecords() const {
274   return {record_};
275 }
276 
GetNextSendTime()277 Clock::time_point MdnsRecordTracker::GetNextSendTime() {
278   OSP_DCHECK(attempt_count_ < countof(kTtlFractions));
279 
280   double ttl_fraction = kTtlFractions[attempt_count_++];
281 
282   // Do not add random variation to the expiration time (last fraction of TTL)
283   if (attempt_count_ != countof(kTtlFractions)) {
284     ttl_fraction += random_delay_->GetRecordTtlVariation();
285   }
286 
287   const Clock::duration delay =
288       Clock::to_duration(record_.ttl() * ttl_fraction);
289   return start_time_ + delay;
290 }
291 
MdnsQuestionTracker(MdnsQuestion question,MdnsSender * sender,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,const Config & config,QueryType query_type)292 MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question,
293                                          MdnsSender* sender,
294                                          TaskRunner* task_runner,
295                                          ClockNowFunctionPtr now_function,
296                                          MdnsRandom* random_delay,
297                                          const Config& config,
298                                          QueryType query_type)
299     : MdnsTracker(sender,
300                   task_runner,
301                   now_function,
302                   random_delay,
303                   TrackerType::kQuestionTracker),
304       question_(std::move(question)),
305       send_delay_(kMinimumQueryInterval),
306       query_type_(query_type),
307       maximum_announcement_count_(config.new_query_announcement_count < 0
308                                       ? INT_MAX
309                                       : config.new_query_announcement_count) {
310   // Initialize the last send time to time_point::min() so that the next call to
311   // SendQuery() is guaranteed to query the network.
312   last_send_time_ = TrivialClockTraits::time_point::min();
313 
314   // The initial query has to be sent after a random delay of 20-120
315   // milliseconds.
316   if (announcements_so_far_ < maximum_announcement_count_) {
317     announcements_so_far_++;
318 
319     if (query_type_ == QueryType::kOneShot) {
320       task_runner_->PostTask([this] { MdnsQuestionTracker::SendQuery(); });
321     } else {
322       OSP_DCHECK(query_type_ == QueryType::kContinuous);
323       send_alarm_.ScheduleFromNow(
324           [this]() {
325             MdnsQuestionTracker::SendQuery();
326             ScheduleFollowUpQuery();
327           },
328           random_delay_->GetInitialQueryDelay());
329     }
330   }
331 }
332 
333 MdnsQuestionTracker::~MdnsQuestionTracker() = default;
334 
AddAssociatedRecord(const MdnsRecordTracker * record_tracker) const335 bool MdnsQuestionTracker::AddAssociatedRecord(
336     const MdnsRecordTracker* record_tracker) const {
337   return AddAdjacentNode(record_tracker);
338 }
339 
RemoveAssociatedRecord(const MdnsRecordTracker * record_tracker) const340 bool MdnsQuestionTracker::RemoveAssociatedRecord(
341     const MdnsRecordTracker* record_tracker) const {
342   return RemoveAdjacentNode(record_tracker);
343 }
344 
GetRecords() const345 std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const {
346   std::vector<MdnsRecord> records;
347   for (const MdnsTracker* tracker : adjacent_nodes()) {
348     OSP_DCHECK(tracker->tracker_type() == TrackerType::kRecordTracker);
349 
350     // This call cannot result in an infinite loop because MdnsRecordTracker
351     // instances only return a single record from this call.
352     std::vector<MdnsRecord> node_records = tracker->GetRecords();
353     OSP_DCHECK(node_records.size() == 1);
354 
355     records.push_back(std::move(node_records[0]));
356   }
357 
358   return records;
359 }
360 
SendQuery() const361 bool MdnsQuestionTracker::SendQuery() const {
362   // NOTE: The RFC does not specify the minimum interval between queries for
363   // multiple records of the same query when initiated for different reasons
364   // (such as for different record refreshes or for one record refresh and the
365   // periodic re-querying for a continuous query). For this reason, a constant
366   // outside of scope of the RFC has been chosen.
367   TrivialClockTraits::time_point now = now_function_();
368   if (now < last_send_time_ + kMinimumQueryInterval) {
369     return true;
370   }
371   last_send_time_ = now;
372 
373   MdnsMessage message(CreateMessageId(), MessageType::Query);
374   message.AddQuestion(question_);
375 
376   // Send the message and additional known answer packets as needed.
377   for (auto it = adjacent_nodes().begin(); it != adjacent_nodes().end();) {
378     OSP_DCHECK((*it)->tracker_type() == TrackerType::kRecordTracker);
379 
380     const MdnsRecordTracker* record_tracker =
381         static_cast<const MdnsRecordTracker*>(*it);
382     if (record_tracker->IsNearingExpiry()) {
383       it++;
384       continue;
385     }
386 
387     // A record tracker should only contain one record.
388     std::vector<MdnsRecord> node_records = (*it)->GetRecords();
389     OSP_DCHECK(node_records.size() == 1);
390     MdnsRecord node_record = std::move(node_records[0]);
391 
392     if (message.CanAddRecord(node_record)) {
393       message.AddAnswer(std::move(node_record));
394       it++;
395     } else if (message.questions().empty() && message.answers().empty()) {
396       // This case should never happen, because it means a record is too large
397       // to fit into its own message.
398       OSP_LOG_INFO
399           << "Encountered unreasonably large message in cache. Skipping "
400           << "known answer in suppressions...";
401       it++;
402     } else {
403       message.set_truncated();
404       sender_->SendMulticast(message);
405       message = MdnsMessage(CreateMessageId(), MessageType::Query);
406     }
407   }
408   sender_->SendMulticast(message);
409   return true;
410 }
411 
ScheduleFollowUpQuery()412 void MdnsQuestionTracker::ScheduleFollowUpQuery() {
413   if (announcements_so_far_ >= maximum_announcement_count_) {
414     return;
415   }
416   announcements_so_far_++;
417 
418   send_alarm_.ScheduleFromNow(
419       [this] {
420         if (SendQuery()) {
421           ScheduleFollowUpQuery();
422         }
423       },
424       send_delay_);
425   send_delay_ = send_delay_ * kIntervalIncreaseFactor;
426   if (send_delay_ > kMaximumQueryInterval) {
427     send_delay_ = kMaximumQueryInterval;
428   }
429 }
430 
431 }  // namespace discovery
432 }  // namespace openscreen
433