1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/mdns/mdns_querier.h"
6 
7 #include <vector>
8 
9 #include "discovery/common/config.h"
10 #include "discovery/common/reporting_client.h"
11 #include "discovery/mdns/mdns_random.h"
12 #include "discovery/mdns/mdns_receiver.h"
13 #include "discovery/mdns/mdns_sender.h"
14 
15 namespace openscreen {
16 namespace discovery {
17 namespace {
18 
19 const std::vector<DnsType> kTranslatedNsecAnyQueryTypes = {
20     DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
21 
IsNegativeResponseFor(const MdnsRecord & record,DnsType type)22 bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
23   if (record.dns_type() != DnsType::kNSEC) {
24     return false;
25   }
26 
27   const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
28 
29   // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
30   // record to indicate this is an mDNS NSEC record rather than a traditional
31   // DNS NSEC record.
32   if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
33       nsec.types().end()) {
34     return false;
35   }
36 
37   return std::find_if(nsec.types().begin(), nsec.types().end(),
38                       [type](DnsType stored_type) {
39                         return stored_type == type ||
40                                stored_type == DnsType::kANY;
41                       }) != nsec.types().end();
42 }
43 
44 }  // namespace
45 
RecordTrackerLruCache(MdnsQuerier * querier,MdnsSender * sender,MdnsRandom * random_delay,TaskRunner * task_runner,ClockNowFunctionPtr now_function,ReportingClient * reporting_client,const Config & config)46 MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
47     MdnsQuerier* querier,
48     MdnsSender* sender,
49     MdnsRandom* random_delay,
50     TaskRunner* task_runner,
51     ClockNowFunctionPtr now_function,
52     ReportingClient* reporting_client,
53     const Config& config)
54     : querier_(querier),
55       sender_(sender),
56       random_delay_(random_delay),
57       task_runner_(task_runner),
58       now_function_(now_function),
59       reporting_client_(reporting_client),
60       config_(config) {
61   OSP_DCHECK(sender_);
62   OSP_DCHECK(random_delay_);
63   OSP_DCHECK(task_runner_);
64   OSP_DCHECK(reporting_client_);
65   OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
66 }
67 
68 std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name)69 MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
70   return Find(name, DnsType::kANY, DnsClass::kANY);
71 }
72 
73 std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name,DnsType dns_type,DnsClass dns_class)74 MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
75                                          DnsType dns_type,
76                                          DnsClass dns_class) {
77   std::vector<RecordTrackerConstRef> results;
78   auto pair = records_.equal_range(name);
79   for (auto it = pair.first; it != pair.second; it++) {
80     const MdnsRecordTracker& tracker = *it->second;
81     if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) &&
82         (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) {
83       results.push_back(std::cref(tracker));
84     }
85   }
86 
87   return results;
88 }
89 
Erase(const DomainName & domain,TrackerApplicableCheck check)90 int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
91                                               TrackerApplicableCheck check) {
92   auto pair = records_.equal_range(domain);
93   int count = 0;
94   for (RecordMap::iterator it = pair.first; it != pair.second;) {
95     if (check(*it->second)) {
96       lru_order_.erase(it->second);
97       it = records_.erase(it);
98       count++;
99     } else {
100       it++;
101     }
102   }
103 
104   return count;
105 }
106 
ExpireSoon(const DomainName & domain,TrackerApplicableCheck check)107 int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
108     const DomainName& domain,
109     TrackerApplicableCheck check) {
110   auto pair = records_.equal_range(domain);
111   int count = 0;
112   for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
113     if (check(*it->second)) {
114       MoveToEnd(it);
115       it->second->ExpireSoon();
116       count++;
117     }
118   }
119 
120   return count;
121 }
122 
Update(const MdnsRecord & record,TrackerApplicableCheck check)123 int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
124                                                TrackerApplicableCheck check) {
125   return Update(record, check, [](const MdnsRecordTracker& t) {});
126 }
127 
Update(const MdnsRecord & record,TrackerApplicableCheck check,TrackerChangeCallback on_rdata_update)128 int MdnsQuerier::RecordTrackerLruCache::Update(
129     const MdnsRecord& record,
130     TrackerApplicableCheck check,
131     TrackerChangeCallback on_rdata_update) {
132   auto pair = records_.equal_range(record.name());
133   int count = 0;
134   for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
135     if (check(*it->second)) {
136       auto result = it->second->Update(record);
137 
138       if (result.is_error()) {
139         reporting_client_->OnRecoverableError(
140             Error(Error::Code::kUpdateReceivedRecordFailure,
141                   result.error().ToString()));
142         continue;
143       }
144 
145       count++;
146       if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
147         it->second->ExpireSoon();
148         MoveToEnd(it);
149       } else {
150         MoveToBeginning(it);
151         if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
152           on_rdata_update(*it->second);
153         }
154       }
155     }
156   }
157 
158   return count;
159 }
160 
StartTracking(MdnsRecord record,DnsType dns_type)161 const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
162     MdnsRecord record,
163     DnsType dns_type) {
164   auto expiration_callback = [this](const MdnsRecordTracker* tracker,
165                                     const MdnsRecord& record) {
166     querier_->OnRecordExpired(tracker, record);
167   };
168 
169   while (lru_order_.size() >=
170          static_cast<size_t>(config_.querier_max_records_cached)) {
171     // This call erases one of the tracked records.
172     OSP_DVLOG << "Maximum cacheable record count exceeded ("
173               << config_.querier_max_records_cached << ")";
174     lru_order_.back().ExpireNow();
175   }
176 
177   auto name = record.name();
178   lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
179                            now_function_, random_delay_,
180                            std::move(expiration_callback));
181   records_.emplace(std::move(name), lru_order_.begin());
182 
183   return lru_order_.front();
184 }
185 
MoveToBeginning(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)186 void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
187     MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
188   lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
189   it->second = lru_order_.begin();
190 }
191 
MoveToEnd(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)192 void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
193     MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
194   lru_order_.splice(lru_order_.end(), lru_order_, it->second);
195   it->second = --lru_order_.end();
196 }
197 
MdnsQuerier(MdnsSender * sender,MdnsReceiver * receiver,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,ReportingClient * reporting_client,Config config)198 MdnsQuerier::MdnsQuerier(MdnsSender* sender,
199                          MdnsReceiver* receiver,
200                          TaskRunner* task_runner,
201                          ClockNowFunctionPtr now_function,
202                          MdnsRandom* random_delay,
203                          ReportingClient* reporting_client,
204                          Config config)
205     : sender_(sender),
206       receiver_(receiver),
207       task_runner_(task_runner),
208       now_function_(now_function),
209       random_delay_(random_delay),
210       reporting_client_(reporting_client),
211       config_(std::move(config)),
212       records_(this,
213                sender_,
214                random_delay_,
215                task_runner_,
216                now_function_,
217                reporting_client_,
218                config_) {
219   OSP_DCHECK(sender_);
220   OSP_DCHECK(receiver_);
221   OSP_DCHECK(task_runner_);
222   OSP_DCHECK(now_function_);
223   OSP_DCHECK(random_delay_);
224   OSP_DCHECK(reporting_client_);
225 
226   receiver_->AddResponseCallback(this);
227 }
228 
~MdnsQuerier()229 MdnsQuerier::~MdnsQuerier() {
230   receiver_->RemoveResponseCallback(this);
231 }
232 
// NOTE: The code below uses explicit loops instead of std::find_if, for better
// readability, brevity and homogeneity. Using std::find_if results in a few
// more lines of code, and readability suffers from the extra lambdas.
236 
StartQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)237 void MdnsQuerier::StartQuery(const DomainName& name,
238                              DnsType dns_type,
239                              DnsClass dns_class,
240                              MdnsRecordChangedCallback* callback) {
241   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
242   OSP_DCHECK(callback);
243   OSP_DCHECK(dns_type != DnsType::kNSEC);
244 
245   // Add a new callback if haven't seen it before
246   auto callbacks_it = callbacks_.equal_range(name);
247   for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
248     const CallbackInfo& callback_info = entry->second;
249     if (dns_type == callback_info.dns_type &&
250         dns_class == callback_info.dns_class &&
251         callback == callback_info.callback) {
252       // Already have this callback
253       return;
254     }
255   }
256   callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});
257 
258   // Notify the new callback with previously cached records.
259   // NOTE: In the future, could allow callers to fetch cached records after
260   // adding a callback, for example to prime the UI.
261   const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
262       records_.Find(name, dns_type, dns_class);
263   for (const MdnsRecordTracker& tracker : trackers) {
264     if (!tracker.is_negative_response()) {
265       MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
266                                tracker.record_type(), tracker.ttl(),
267                                tracker.rdata());
268       callback->OnRecordChanged(std::move(stored_record),
269                                 RecordChangedEvent::kCreated);
270     }
271   }
272 
273   // Add a new question if haven't seen it before
274   auto questions_it = questions_.equal_range(name);
275   for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
276     const MdnsQuestion& tracked_question = entry->second->question();
277     if (dns_type == tracked_question.dns_type() &&
278         dns_class == tracked_question.dns_class()) {
279       // Already have this question
280       return;
281     }
282   }
283   AddQuestion(
284       MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
285 }
286 
StopQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)287 void MdnsQuerier::StopQuery(const DomainName& name,
288                             DnsType dns_type,
289                             DnsClass dns_class,
290                             MdnsRecordChangedCallback* callback) {
291   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
292   OSP_DCHECK(callback);
293   OSP_DCHECK(dns_type != DnsType::kNSEC);
294 
295   // Find and remove the callback.
296   int callbacks_for_key = 0;
297   auto callbacks_it = callbacks_.equal_range(name);
298   for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
299     const CallbackInfo& callback_info = entry->second;
300     if (dns_type == callback_info.dns_type &&
301         dns_class == callback_info.dns_class) {
302       if (callback == callback_info.callback) {
303         entry = callbacks_.erase(entry);
304       } else {
305         ++callbacks_for_key;
306         ++entry;
307       }
308     }
309   }
310 
311   // Exit if there are still callbacks registered for DomainName + DnsType +
312   // DnsClass
313   if (callbacks_for_key > 0) {
314     return;
315   }
316 
317   // Find and delete a question that does not have any associated callbacks
318   auto questions_it = questions_.equal_range(name);
319   for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
320     const MdnsQuestion& tracked_question = entry->second->question();
321     if (dns_type == tracked_question.dns_type() &&
322         dns_class == tracked_question.dns_class()) {
323       questions_.erase(entry);
324       return;
325     }
326   }
327 
328   // TODO(crbug.com/openscreen/83): Find and delete all records that no longer
329   // answer any questions, if a question was deleted.  It's possible the same
330   // query will be added back before the records expire, so this behavior could
331   // be configurable by the caller.
332 }
333 
ReinitializeQueries(const DomainName & name)334 void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
335   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
336 
337   // Get the ongoing queries and their callbacks.
338   std::vector<CallbackInfo> callbacks;
339   auto its = callbacks_.equal_range(name);
340   for (auto it = its.first; it != its.second; it++) {
341     callbacks.push_back(std::move(it->second));
342   }
343   callbacks_.erase(name);
344 
345   // Remove all known questions and answers.
346   questions_.erase(name);
347   records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
348 
349   // Restart the queries.
350   for (const auto& cb : callbacks) {
351     StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
352   }
353 }
354 
OnMessageReceived(const MdnsMessage & message)355 void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
356   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
357   OSP_DCHECK(message.type() == MessageType::Response);
358 
359   OSP_DVLOG << "Received mDNS Response message with "
360             << message.answers().size() << " answers and "
361             << message.additional_records().size()
362             << " additional records. Processing...";
363 
364   // Add any records that are relevant for this querier.
365   bool found_relevant_records = false;
366   int processed_count = 0;
367   for (const MdnsRecord& record : message.answers()) {
368     if (ShouldAnswerRecordBeProcessed(record)) {
369       ProcessRecord(record);
370       OSP_DVLOG << "\tProcessing answer record for domain '"
371                 << record.name().ToString() << "' of type '"
372                 << record.dns_type() << "'...";
373       found_relevant_records = true;
374       processed_count++;
375     }
376   }
377 
378   // If any of the message's answers are relevant, add all additional records.
379   // Else, since the message has already been received and parsed, use any
380   // individual records relevant to this querier to update the cache.
381   for (const MdnsRecord& record : message.additional_records()) {
382     if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
383       OSP_DVLOG << "\tProcessing additional record for domain '"
384                 << record.name().ToString() << "' of type '"
385                 << record.dns_type() << "'...";
386       ProcessRecord(record);
387       processed_count++;
388     }
389   }
390 
391   OSP_DVLOG << "\tmDNS Response processed (" << processed_count
392             << " records accepted)!";
393 
394   // TODO(crbug.com/openscreen/83): Check authority records.
395 }
396 
ShouldAnswerRecordBeProcessed(const MdnsRecord & answer)397 bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
398   // First, accept the record if it's associated with an ongoing question.
399   const auto questions_range = questions_.equal_range(answer.name());
400   const auto it = std::find_if(
401       questions_range.first, questions_range.second,
402       [&answer](const auto& pair) {
403         return (pair.second->question().dns_type() == DnsType::kANY ||
404                 IsNegativeResponseFor(answer,
405                                       pair.second->question().dns_type()) ||
406                 pair.second->question().dns_type() == answer.dns_type()) &&
407                (pair.second->question().dns_class() == DnsClass::kANY ||
408                 pair.second->question().dns_class() == answer.dns_class());
409       });
410   if (it != questions_range.second) {
411     return true;
412   }
413 
414   // If not, check if it corresponds to an already existing record. This is
415   // required because records which are already stored may either have been
416   // received in an additional records section, or are associated with a query
417   // which is no longer active.
418   std::vector<DnsType> types{answer.dns_type()};
419   if (answer.dns_type() == DnsType::kNSEC) {
420     const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
421     types = nsec_rdata.types();
422   }
423 
424   for (DnsType type : types) {
425     std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
426         records_.Find(answer.name(), type, answer.dns_class());
427     if (!trackers.empty()) {
428       return true;
429     }
430   }
431 
432   // In all other cases, the record isn't relevant. Drop it.
433   return false;
434 }
435 
OnRecordExpired(const MdnsRecordTracker * tracker,const MdnsRecord & record)436 void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
437                                   const MdnsRecord& record) {
438   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
439 
440   if (!tracker->is_negative_response()) {
441     ProcessCallbacks(record, RecordChangedEvent::kExpired);
442   }
443 
444   records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
445     return tracker == &it_tracker;
446   });
447 }
448 
ProcessRecord(const MdnsRecord & record)449 void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
450   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
451 
452   // Get the types which the received record is associated with. In most cases
453   // this will only be the type of the provided record, but in the case of
454   // NSEC records this will be all records which the record dictates the
455   // nonexistence of.
456   std::vector<DnsType> types;
457   const std::vector<DnsType>* types_ptr = &types;
458   if (record.dns_type() == DnsType::kNSEC) {
459     const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
460     if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
461                   DnsType::kANY) != nsec_rdata.types().end()) {
462       types_ptr = &kTranslatedNsecAnyQueryTypes;
463     } else {
464       types_ptr = &nsec_rdata.types();
465     }
466   } else {
467     types.push_back(record.dns_type());
468   }
469 
470   // Apply the update for each type that the record is associated with.
471   for (DnsType dns_type : *types_ptr) {
472     switch (record.record_type()) {
473       case RecordType::kShared: {
474         ProcessSharedRecord(record, dns_type);
475         break;
476       }
477       case RecordType::kUnique: {
478         ProcessUniqueRecord(record, dns_type);
479         break;
480       }
481     }
482   }
483 }
484 
ProcessSharedRecord(const MdnsRecord & record,DnsType dns_type)485 void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
486                                       DnsType dns_type) {
487   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
488   OSP_DCHECK(record.record_type() == RecordType::kShared);
489 
490   // By design, NSEC records are never shared records.
491   if (record.dns_type() == DnsType::kNSEC) {
492     return;
493   }
494 
495   // For any records updated, this host already has this shared record. Since
496   // the RDATA matches, this is only a TTL update.
497   auto check = [&record](const MdnsRecordTracker& tracker) {
498     return record.dns_type() == tracker.dns_type() &&
499            record.dns_class() == tracker.dns_class() &&
500            record.rdata() == tracker.rdata();
501   };
502   auto updated_count = records_.Update(record, std::move(check));
503 
504   if (!updated_count) {
505     // Have never before seen this shared record, insert a new one.
506     AddRecord(record, dns_type);
507     ProcessCallbacks(record, RecordChangedEvent::kCreated);
508   }
509 }
510 
ProcessUniqueRecord(const MdnsRecord & record,DnsType dns_type)511 void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
512                                       DnsType dns_type) {
513   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
514   OSP_DCHECK(record.record_type() == RecordType::kUnique);
515 
516   std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
517       records_.Find(record.name(), dns_type, record.dns_class());
518   size_t num_records_for_key = trackers.size();
519 
520   // Have not seen any records with this key before. This case is expected the
521   // first time a record is received.
522   if (num_records_for_key == size_t{0}) {
523     const bool will_exist = record.dns_type() != DnsType::kNSEC;
524     AddRecord(record, dns_type);
525     if (will_exist) {
526       ProcessCallbacks(record, RecordChangedEvent::kCreated);
527     }
528   }
529 
530   // There is exactly one tracker associated with this key. This is the expected
531   // case when a record matching this one has already been seen.
532   else if (num_records_for_key == size_t{1}) {
533     ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
534   }
535 
536   // Multiple records with the same key.
537   else {
538     ProcessMultiTrackedUniqueRecord(record, dns_type);
539   }
540 }
541 
ProcessSinglyTrackedUniqueRecord(const MdnsRecord & record,const MdnsRecordTracker & tracker)542 void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
543     const MdnsRecord& record,
544     const MdnsRecordTracker& tracker) {
545   const bool existed_previously = !tracker.is_negative_response();
546   const bool will_exist = record.dns_type() != DnsType::kNSEC;
547 
548   // Calculate the callback to call on record update success while the old
549   // record still exists.
550   MdnsRecord record_for_callback = record;
551   if (existed_previously && !will_exist) {
552     record_for_callback =
553         MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
554                    tracker.record_type(), tracker.ttl(), tracker.rdata());
555   }
556 
557   auto on_rdata_change = [this, r = std::move(record_for_callback),
558                           existed_previously,
559                           will_exist](const MdnsRecordTracker& tracker) {
560     // If RDATA on the record is different, notify that the record has
561     // been updated.
562     if (existed_previously && will_exist) {
563       ProcessCallbacks(r, RecordChangedEvent::kUpdated);
564     } else if (existed_previously) {
565       // Do not expire the tracker, because it still holds an NSEC record.
566       ProcessCallbacks(r, RecordChangedEvent::kExpired);
567     } else if (will_exist) {
568       ProcessCallbacks(r, RecordChangedEvent::kCreated);
569     }
570   };
571 
572   int updated_count = records_.Update(
573       record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
574       std::move(on_rdata_change));
575   OSP_DCHECK_EQ(updated_count, 1);
576 }
577 
ProcessMultiTrackedUniqueRecord(const MdnsRecord & record,DnsType dns_type)578 void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
579                                                   DnsType dns_type) {
580   auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
581     return tracker.dns_type() == dns_type &&
582            tracker.dns_class() == record.dns_class() &&
583            tracker.rdata() == record.rdata();
584   };
585   int update_count = records_.Update(
586       record, std::move(update_check),
587       [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
588   OSP_DCHECK_LE(update_count, 1);
589 
590   auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
591     return tracker.dns_type() == dns_type &&
592            tracker.dns_class() == record.dns_class() &&
593            tracker.rdata() != record.rdata();
594   };
595   int expire_count =
596       records_.ExpireSoon(record.name(), std::move(expire_check));
597   OSP_DCHECK_GE(expire_count, 1);
598 
599   // Did not find an existing record to update.
600   if (!update_count && !expire_count) {
601     AddRecord(record, dns_type);
602     if (record.dns_type() != DnsType::kNSEC) {
603       ProcessCallbacks(record, RecordChangedEvent::kCreated);
604     }
605   }
606 }
607 
ProcessCallbacks(const MdnsRecord & record,RecordChangedEvent event)608 void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
609                                    RecordChangedEvent event) {
610   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
611 
612   auto callbacks_it = callbacks_.equal_range(record.name());
613   for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
614     const CallbackInfo& callback_info = entry->second;
615     if ((callback_info.dns_type == DnsType::kANY ||
616          record.dns_type() == callback_info.dns_type) &&
617         (callback_info.dns_class == DnsClass::kANY ||
618          record.dns_class() == callback_info.dns_class)) {
619       callback_info.callback->OnRecordChanged(record, event);
620     }
621   }
622 }
623 
AddQuestion(const MdnsQuestion & question)624 void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
625   auto tracker = std::make_unique<MdnsQuestionTracker>(
626       std::move(question), sender_, task_runner_, now_function_, random_delay_,
627       config_);
628   MdnsQuestionTracker* ptr = tracker.get();
629   questions_.emplace(question.name(), std::move(tracker));
630 
631   // Let all records associated with this question know that there is a new
632   // query that can be used for their refresh.
633   std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
634       records_.Find(question.name(), question.dns_type(), question.dns_class());
635   for (const MdnsRecordTracker& tracker : trackers) {
636     // NOTE: When the pointed to object is deleted, its dtor removes itself
637     // from all associated records.
638     ptr->AddAssociatedRecord(&tracker);
639   }
640 }
641 
AddRecord(const MdnsRecord & record,DnsType type)642 void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
643   // Add the new record.
644   const auto& tracker = records_.StartTracking(record, type);
645 
646   // Let all questions associated with this record know that there is a new
647   // record that answers them (for known answer suppression).
648   auto query_it = questions_.equal_range(record.name());
649   for (auto entry = query_it.first; entry != query_it.second; ++entry) {
650     const MdnsQuestion& query = entry->second->question();
651     const bool is_relevant_type =
652         type == DnsType::kANY || type == query.dns_type();
653     const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
654                                    record.dns_class() == query.dns_class();
655     if (is_relevant_type && is_relevant_class) {
656       // NOTE: When the pointed to object is deleted, its dtor removes itself
657       // from all associated queries.
658       entry->second->AddAssociatedRecord(&tracker);
659     }
660   }
661 }
662 
663 }  // namespace discovery
664 }  // namespace openscreen
665