1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/mdns/mdns_querier.h"
6
7 #include <vector>
8
9 #include "discovery/common/config.h"
10 #include "discovery/common/reporting_client.h"
11 #include "discovery/mdns/mdns_random.h"
12 #include "discovery/mdns/mdns_receiver.h"
13 #include "discovery/mdns/mdns_sender.h"
14
15 namespace openscreen {
16 namespace discovery {
17 namespace {
18
19 const std::vector<DnsType> kTranslatedNsecAnyQueryTypes = {
20 DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
21
IsNegativeResponseFor(const MdnsRecord & record,DnsType type)22 bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
23 if (record.dns_type() != DnsType::kNSEC) {
24 return false;
25 }
26
27 const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
28
29 // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
30 // record to indicate this is an mDNS NSEC record rather than a traditional
31 // DNS NSEC record.
32 if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
33 nsec.types().end()) {
34 return false;
35 }
36
37 return std::find_if(nsec.types().begin(), nsec.types().end(),
38 [type](DnsType stored_type) {
39 return stored_type == type ||
40 stored_type == DnsType::kANY;
41 }) != nsec.types().end();
42 }
43
44 } // namespace
45
RecordTrackerLruCache(MdnsQuerier * querier,MdnsSender * sender,MdnsRandom * random_delay,TaskRunner * task_runner,ClockNowFunctionPtr now_function,ReportingClient * reporting_client,const Config & config)46 MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
47 MdnsQuerier* querier,
48 MdnsSender* sender,
49 MdnsRandom* random_delay,
50 TaskRunner* task_runner,
51 ClockNowFunctionPtr now_function,
52 ReportingClient* reporting_client,
53 const Config& config)
54 : querier_(querier),
55 sender_(sender),
56 random_delay_(random_delay),
57 task_runner_(task_runner),
58 now_function_(now_function),
59 reporting_client_(reporting_client),
60 config_(config) {
61 OSP_DCHECK(sender_);
62 OSP_DCHECK(random_delay_);
63 OSP_DCHECK(task_runner_);
64 OSP_DCHECK(reporting_client_);
65 OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
66 }
67
68 std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name)69 MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
70 return Find(name, DnsType::kANY, DnsClass::kANY);
71 }
72
73 std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name,DnsType dns_type,DnsClass dns_class)74 MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
75 DnsType dns_type,
76 DnsClass dns_class) {
77 std::vector<RecordTrackerConstRef> results;
78 auto pair = records_.equal_range(name);
79 for (auto it = pair.first; it != pair.second; it++) {
80 const MdnsRecordTracker& tracker = *it->second;
81 if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) &&
82 (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) {
83 results.push_back(std::cref(tracker));
84 }
85 }
86
87 return results;
88 }
89
Erase(const DomainName & domain,TrackerApplicableCheck check)90 int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
91 TrackerApplicableCheck check) {
92 auto pair = records_.equal_range(domain);
93 int count = 0;
94 for (RecordMap::iterator it = pair.first; it != pair.second;) {
95 if (check(*it->second)) {
96 lru_order_.erase(it->second);
97 it = records_.erase(it);
98 count++;
99 } else {
100 it++;
101 }
102 }
103
104 return count;
105 }
106
ExpireSoon(const DomainName & domain,TrackerApplicableCheck check)107 int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
108 const DomainName& domain,
109 TrackerApplicableCheck check) {
110 auto pair = records_.equal_range(domain);
111 int count = 0;
112 for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
113 if (check(*it->second)) {
114 MoveToEnd(it);
115 it->second->ExpireSoon();
116 count++;
117 }
118 }
119
120 return count;
121 }
122
Update(const MdnsRecord & record,TrackerApplicableCheck check)123 int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
124 TrackerApplicableCheck check) {
125 return Update(record, check, [](const MdnsRecordTracker& t) {});
126 }
127
Update(const MdnsRecord & record,TrackerApplicableCheck check,TrackerChangeCallback on_rdata_update)128 int MdnsQuerier::RecordTrackerLruCache::Update(
129 const MdnsRecord& record,
130 TrackerApplicableCheck check,
131 TrackerChangeCallback on_rdata_update) {
132 auto pair = records_.equal_range(record.name());
133 int count = 0;
134 for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
135 if (check(*it->second)) {
136 auto result = it->second->Update(record);
137
138 if (result.is_error()) {
139 reporting_client_->OnRecoverableError(
140 Error(Error::Code::kUpdateReceivedRecordFailure,
141 result.error().ToString()));
142 continue;
143 }
144
145 count++;
146 if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
147 it->second->ExpireSoon();
148 MoveToEnd(it);
149 } else {
150 MoveToBeginning(it);
151 if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
152 on_rdata_update(*it->second);
153 }
154 }
155 }
156 }
157
158 return count;
159 }
160
StartTracking(MdnsRecord record,DnsType dns_type)161 const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
162 MdnsRecord record,
163 DnsType dns_type) {
164 auto expiration_callback = [this](const MdnsRecordTracker* tracker,
165 const MdnsRecord& record) {
166 querier_->OnRecordExpired(tracker, record);
167 };
168
169 while (lru_order_.size() >=
170 static_cast<size_t>(config_.querier_max_records_cached)) {
171 // This call erases one of the tracked records.
172 OSP_DVLOG << "Maximum cacheable record count exceeded ("
173 << config_.querier_max_records_cached << ")";
174 lru_order_.back().ExpireNow();
175 }
176
177 auto name = record.name();
178 lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
179 now_function_, random_delay_,
180 std::move(expiration_callback));
181 records_.emplace(std::move(name), lru_order_.begin());
182
183 return lru_order_.front();
184 }
185
MoveToBeginning(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)186 void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
187 MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
188 lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
189 it->second = lru_order_.begin();
190 }
191
MoveToEnd(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)192 void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
193 MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
194 lru_order_.splice(lru_order_.end(), lru_order_, it->second);
195 it->second = --lru_order_.end();
196 }
197
// Stores the supplied dependencies, builds the record cache from them and
// registers this querier with |receiver| as a response callback.
MdnsQuerier::MdnsQuerier(MdnsSender* sender,
                         MdnsReceiver* receiver,
                         TaskRunner* task_runner,
                         ClockNowFunctionPtr now_function,
                         MdnsRandom* random_delay,
                         ReportingClient* reporting_client,
                         Config config)
    : sender_(sender),
      receiver_(receiver),
      task_runner_(task_runner),
      now_function_(now_function),
      random_delay_(random_delay),
      reporting_client_(reporting_client),
      config_(std::move(config)),
      // NOTE(review): |records_| is handed |config_| and the other members, so
      // it presumably has to be declared after them in the class - confirm in
      // the header.
      records_(this,
               sender_,
               random_delay_,
               task_runner_,
               now_function_,
               reporting_client_,
               config_) {
  OSP_DCHECK(sender_);
  OSP_DCHECK(receiver_);
  OSP_DCHECK(task_runner_);
  OSP_DCHECK(now_function_);
  OSP_DCHECK(random_delay_);
  OSP_DCHECK(reporting_client_);

  // Paired with RemoveResponseCallback() in the destructor.
  receiver_->AddResponseCallback(this);
}
228
// Unregisters this querier from |receiver_|, undoing the AddResponseCallback()
// call made in the constructor.
MdnsQuerier::~MdnsQuerier() {
  receiver_->RemoveResponseCallback(this);
}
232
// NOTE: The code below uses range-based loops instead of std::find_if for
// better readability, brevity and homogeneity. Using std::find_if would add a
// few more lines of code, and readability would suffer from the extra lambdas.
236
StartQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)237 void MdnsQuerier::StartQuery(const DomainName& name,
238 DnsType dns_type,
239 DnsClass dns_class,
240 MdnsRecordChangedCallback* callback) {
241 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
242 OSP_DCHECK(callback);
243 OSP_DCHECK(dns_type != DnsType::kNSEC);
244
245 // Add a new callback if haven't seen it before
246 auto callbacks_it = callbacks_.equal_range(name);
247 for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
248 const CallbackInfo& callback_info = entry->second;
249 if (dns_type == callback_info.dns_type &&
250 dns_class == callback_info.dns_class &&
251 callback == callback_info.callback) {
252 // Already have this callback
253 return;
254 }
255 }
256 callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});
257
258 // Notify the new callback with previously cached records.
259 // NOTE: In the future, could allow callers to fetch cached records after
260 // adding a callback, for example to prime the UI.
261 const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
262 records_.Find(name, dns_type, dns_class);
263 for (const MdnsRecordTracker& tracker : trackers) {
264 if (!tracker.is_negative_response()) {
265 MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
266 tracker.record_type(), tracker.ttl(),
267 tracker.rdata());
268 callback->OnRecordChanged(std::move(stored_record),
269 RecordChangedEvent::kCreated);
270 }
271 }
272
273 // Add a new question if haven't seen it before
274 auto questions_it = questions_.equal_range(name);
275 for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
276 const MdnsQuestion& tracked_question = entry->second->question();
277 if (dns_type == tracked_question.dns_type() &&
278 dns_class == tracked_question.dns_class()) {
279 // Already have this question
280 return;
281 }
282 }
283 AddQuestion(
284 MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
285 }
286
StopQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)287 void MdnsQuerier::StopQuery(const DomainName& name,
288 DnsType dns_type,
289 DnsClass dns_class,
290 MdnsRecordChangedCallback* callback) {
291 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
292 OSP_DCHECK(callback);
293 OSP_DCHECK(dns_type != DnsType::kNSEC);
294
295 // Find and remove the callback.
296 int callbacks_for_key = 0;
297 auto callbacks_it = callbacks_.equal_range(name);
298 for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
299 const CallbackInfo& callback_info = entry->second;
300 if (dns_type == callback_info.dns_type &&
301 dns_class == callback_info.dns_class) {
302 if (callback == callback_info.callback) {
303 entry = callbacks_.erase(entry);
304 } else {
305 ++callbacks_for_key;
306 ++entry;
307 }
308 }
309 }
310
311 // Exit if there are still callbacks registered for DomainName + DnsType +
312 // DnsClass
313 if (callbacks_for_key > 0) {
314 return;
315 }
316
317 // Find and delete a question that does not have any associated callbacks
318 auto questions_it = questions_.equal_range(name);
319 for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
320 const MdnsQuestion& tracked_question = entry->second->question();
321 if (dns_type == tracked_question.dns_type() &&
322 dns_class == tracked_question.dns_class()) {
323 questions_.erase(entry);
324 return;
325 }
326 }
327
328 // TODO(crbug.com/openscreen/83): Find and delete all records that no longer
329 // answer any questions, if a question was deleted. It's possible the same
330 // query will be added back before the records expire, so this behavior could
331 // be configurable by the caller.
332 }
333
ReinitializeQueries(const DomainName & name)334 void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
335 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
336
337 // Get the ongoing queries and their callbacks.
338 std::vector<CallbackInfo> callbacks;
339 auto its = callbacks_.equal_range(name);
340 for (auto it = its.first; it != its.second; it++) {
341 callbacks.push_back(std::move(it->second));
342 }
343 callbacks_.erase(name);
344
345 // Remove all known questions and answers.
346 questions_.erase(name);
347 records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
348
349 // Restart the queries.
350 for (const auto& cb : callbacks) {
351 StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
352 }
353 }
354
OnMessageReceived(const MdnsMessage & message)355 void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
356 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
357 OSP_DCHECK(message.type() == MessageType::Response);
358
359 OSP_DVLOG << "Received mDNS Response message with "
360 << message.answers().size() << " answers and "
361 << message.additional_records().size()
362 << " additional records. Processing...";
363
364 // Add any records that are relevant for this querier.
365 bool found_relevant_records = false;
366 int processed_count = 0;
367 for (const MdnsRecord& record : message.answers()) {
368 if (ShouldAnswerRecordBeProcessed(record)) {
369 ProcessRecord(record);
370 OSP_DVLOG << "\tProcessing answer record for domain '"
371 << record.name().ToString() << "' of type '"
372 << record.dns_type() << "'...";
373 found_relevant_records = true;
374 processed_count++;
375 }
376 }
377
378 // If any of the message's answers are relevant, add all additional records.
379 // Else, since the message has already been received and parsed, use any
380 // individual records relevant to this querier to update the cache.
381 for (const MdnsRecord& record : message.additional_records()) {
382 if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
383 OSP_DVLOG << "\tProcessing additional record for domain '"
384 << record.name().ToString() << "' of type '"
385 << record.dns_type() << "'...";
386 ProcessRecord(record);
387 processed_count++;
388 }
389 }
390
391 OSP_DVLOG << "\tmDNS Response processed (" << processed_count
392 << " records accepted)!";
393
394 // TODO(crbug.com/openscreen/83): Check authority records.
395 }
396
ShouldAnswerRecordBeProcessed(const MdnsRecord & answer)397 bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
398 // First, accept the record if it's associated with an ongoing question.
399 const auto questions_range = questions_.equal_range(answer.name());
400 const auto it = std::find_if(
401 questions_range.first, questions_range.second,
402 [&answer](const auto& pair) {
403 return (pair.second->question().dns_type() == DnsType::kANY ||
404 IsNegativeResponseFor(answer,
405 pair.second->question().dns_type()) ||
406 pair.second->question().dns_type() == answer.dns_type()) &&
407 (pair.second->question().dns_class() == DnsClass::kANY ||
408 pair.second->question().dns_class() == answer.dns_class());
409 });
410 if (it != questions_range.second) {
411 return true;
412 }
413
414 // If not, check if it corresponds to an already existing record. This is
415 // required because records which are already stored may either have been
416 // received in an additional records section, or are associated with a query
417 // which is no longer active.
418 std::vector<DnsType> types{answer.dns_type()};
419 if (answer.dns_type() == DnsType::kNSEC) {
420 const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
421 types = nsec_rdata.types();
422 }
423
424 for (DnsType type : types) {
425 std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
426 records_.Find(answer.name(), type, answer.dns_class());
427 if (!trackers.empty()) {
428 return true;
429 }
430 }
431
432 // In all other cases, the record isn't relevant. Drop it.
433 return false;
434 }
435
OnRecordExpired(const MdnsRecordTracker * tracker,const MdnsRecord & record)436 void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
437 const MdnsRecord& record) {
438 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
439
440 if (!tracker->is_negative_response()) {
441 ProcessCallbacks(record, RecordChangedEvent::kExpired);
442 }
443
444 records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
445 return tracker == &it_tracker;
446 });
447 }
448
ProcessRecord(const MdnsRecord & record)449 void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
450 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
451
452 // Get the types which the received record is associated with. In most cases
453 // this will only be the type of the provided record, but in the case of
454 // NSEC records this will be all records which the record dictates the
455 // nonexistence of.
456 std::vector<DnsType> types;
457 const std::vector<DnsType>* types_ptr = &types;
458 if (record.dns_type() == DnsType::kNSEC) {
459 const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
460 if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
461 DnsType::kANY) != nsec_rdata.types().end()) {
462 types_ptr = &kTranslatedNsecAnyQueryTypes;
463 } else {
464 types_ptr = &nsec_rdata.types();
465 }
466 } else {
467 types.push_back(record.dns_type());
468 }
469
470 // Apply the update for each type that the record is associated with.
471 for (DnsType dns_type : *types_ptr) {
472 switch (record.record_type()) {
473 case RecordType::kShared: {
474 ProcessSharedRecord(record, dns_type);
475 break;
476 }
477 case RecordType::kUnique: {
478 ProcessUniqueRecord(record, dns_type);
479 break;
480 }
481 }
482 }
483 }
484
ProcessSharedRecord(const MdnsRecord & record,DnsType dns_type)485 void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
486 DnsType dns_type) {
487 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
488 OSP_DCHECK(record.record_type() == RecordType::kShared);
489
490 // By design, NSEC records are never shared records.
491 if (record.dns_type() == DnsType::kNSEC) {
492 return;
493 }
494
495 // For any records updated, this host already has this shared record. Since
496 // the RDATA matches, this is only a TTL update.
497 auto check = [&record](const MdnsRecordTracker& tracker) {
498 return record.dns_type() == tracker.dns_type() &&
499 record.dns_class() == tracker.dns_class() &&
500 record.rdata() == tracker.rdata();
501 };
502 auto updated_count = records_.Update(record, std::move(check));
503
504 if (!updated_count) {
505 // Have never before seen this shared record, insert a new one.
506 AddRecord(record, dns_type);
507 ProcessCallbacks(record, RecordChangedEvent::kCreated);
508 }
509 }
510
ProcessUniqueRecord(const MdnsRecord & record,DnsType dns_type)511 void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
512 DnsType dns_type) {
513 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
514 OSP_DCHECK(record.record_type() == RecordType::kUnique);
515
516 std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
517 records_.Find(record.name(), dns_type, record.dns_class());
518 size_t num_records_for_key = trackers.size();
519
520 // Have not seen any records with this key before. This case is expected the
521 // first time a record is received.
522 if (num_records_for_key == size_t{0}) {
523 const bool will_exist = record.dns_type() != DnsType::kNSEC;
524 AddRecord(record, dns_type);
525 if (will_exist) {
526 ProcessCallbacks(record, RecordChangedEvent::kCreated);
527 }
528 }
529
530 // There is exactly one tracker associated with this key. This is the expected
531 // case when a record matching this one has already been seen.
532 else if (num_records_for_key == size_t{1}) {
533 ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
534 }
535
536 // Multiple records with the same key.
537 else {
538 ProcessMultiTrackedUniqueRecord(record, dns_type);
539 }
540 }
541
ProcessSinglyTrackedUniqueRecord(const MdnsRecord & record,const MdnsRecordTracker & tracker)542 void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
543 const MdnsRecord& record,
544 const MdnsRecordTracker& tracker) {
545 const bool existed_previously = !tracker.is_negative_response();
546 const bool will_exist = record.dns_type() != DnsType::kNSEC;
547
548 // Calculate the callback to call on record update success while the old
549 // record still exists.
550 MdnsRecord record_for_callback = record;
551 if (existed_previously && !will_exist) {
552 record_for_callback =
553 MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
554 tracker.record_type(), tracker.ttl(), tracker.rdata());
555 }
556
557 auto on_rdata_change = [this, r = std::move(record_for_callback),
558 existed_previously,
559 will_exist](const MdnsRecordTracker& tracker) {
560 // If RDATA on the record is different, notify that the record has
561 // been updated.
562 if (existed_previously && will_exist) {
563 ProcessCallbacks(r, RecordChangedEvent::kUpdated);
564 } else if (existed_previously) {
565 // Do not expire the tracker, because it still holds an NSEC record.
566 ProcessCallbacks(r, RecordChangedEvent::kExpired);
567 } else if (will_exist) {
568 ProcessCallbacks(r, RecordChangedEvent::kCreated);
569 }
570 };
571
572 int updated_count = records_.Update(
573 record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
574 std::move(on_rdata_change));
575 OSP_DCHECK_EQ(updated_count, 1);
576 }
577
ProcessMultiTrackedUniqueRecord(const MdnsRecord & record,DnsType dns_type)578 void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
579 DnsType dns_type) {
580 auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
581 return tracker.dns_type() == dns_type &&
582 tracker.dns_class() == record.dns_class() &&
583 tracker.rdata() == record.rdata();
584 };
585 int update_count = records_.Update(
586 record, std::move(update_check),
587 [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
588 OSP_DCHECK_LE(update_count, 1);
589
590 auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
591 return tracker.dns_type() == dns_type &&
592 tracker.dns_class() == record.dns_class() &&
593 tracker.rdata() != record.rdata();
594 };
595 int expire_count =
596 records_.ExpireSoon(record.name(), std::move(expire_check));
597 OSP_DCHECK_GE(expire_count, 1);
598
599 // Did not find an existing record to update.
600 if (!update_count && !expire_count) {
601 AddRecord(record, dns_type);
602 if (record.dns_type() != DnsType::kNSEC) {
603 ProcessCallbacks(record, RecordChangedEvent::kCreated);
604 }
605 }
606 }
607
ProcessCallbacks(const MdnsRecord & record,RecordChangedEvent event)608 void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
609 RecordChangedEvent event) {
610 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
611
612 auto callbacks_it = callbacks_.equal_range(record.name());
613 for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
614 const CallbackInfo& callback_info = entry->second;
615 if ((callback_info.dns_type == DnsType::kANY ||
616 record.dns_type() == callback_info.dns_type) &&
617 (callback_info.dns_class == DnsClass::kANY ||
618 record.dns_class() == callback_info.dns_class)) {
619 callback_info.callback->OnRecordChanged(record, event);
620 }
621 }
622 }
623
AddQuestion(const MdnsQuestion & question)624 void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
625 auto tracker = std::make_unique<MdnsQuestionTracker>(
626 std::move(question), sender_, task_runner_, now_function_, random_delay_,
627 config_);
628 MdnsQuestionTracker* ptr = tracker.get();
629 questions_.emplace(question.name(), std::move(tracker));
630
631 // Let all records associated with this question know that there is a new
632 // query that can be used for their refresh.
633 std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
634 records_.Find(question.name(), question.dns_type(), question.dns_class());
635 for (const MdnsRecordTracker& tracker : trackers) {
636 // NOTE: When the pointed to object is deleted, its dtor removes itself
637 // from all associated records.
638 ptr->AddAssociatedRecord(&tracker);
639 }
640 }
641
AddRecord(const MdnsRecord & record,DnsType type)642 void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
643 // Add the new record.
644 const auto& tracker = records_.StartTracking(record, type);
645
646 // Let all questions associated with this record know that there is a new
647 // record that answers them (for known answer suppression).
648 auto query_it = questions_.equal_range(record.name());
649 for (auto entry = query_it.first; entry != query_it.second; ++entry) {
650 const MdnsQuestion& query = entry->second->question();
651 const bool is_relevant_type =
652 type == DnsType::kANY || type == query.dns_type();
653 const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
654 record.dns_class() == query.dns_class();
655 if (is_relevant_type && is_relevant_class) {
656 // NOTE: When the pointed to object is deleted, its dtor removes itself
657 // from all associated queries.
658 entry->second->AddAssociatedRecord(&tracker);
659 }
660 }
661 }
662
663 } // namespace discovery
664 } // namespace openscreen
665