1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include <memory>
#include <utility>
#include <vector>

#include "base/bind.h"
#include "base/cancelable_callback.h"
#include "base/location.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/memory/ref_counted.h"
#include "base/run_loop.h"
#include "base/single_thread_task_runner.h"
#include "base/test/simple_test_clock.h"
#include "base/threading/thread_task_runner_handle.h"
#include "base/time/clock.h"
#include "base/time/default_clock.h"
#include "base/timer/mock_timer.h"
#include "base/timer/timer.h"
#include "net/base/address_family.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/ip_address.h"
#include "net/base/rand_callback.h"
#include "net/base/test_completion_callback.h"
#include "net/dns/mdns_client_impl.h"
#include "net/dns/mock_mdns_socket_factory.h"
#include "net/dns/record_rdata.h"
#include "net/log/net_log.h"
#include "net/socket/udp_client_socket.h"
#include "net/test/gtest_util.h"
#include "net/test/test_with_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
36
37 using ::testing::_;
38 using ::testing::Assign;
39 using ::testing::AtMost;
40 using ::testing::Exactly;
41 using ::testing::IgnoreResult;
42 using ::testing::Invoke;
43 using ::testing::InvokeWithoutArgs;
44 using ::testing::NiceMock;
45 using ::testing::Return;
46 using ::testing::SaveArg;
47 using ::testing::StrictMock;
48
49 namespace net {
50
51 namespace {
52
// Response with two PTR answers: _privet._tcp.local -> hello._privet._tcp.local
// (TTL 1s) and _printer._tcp.local -> hello._printer._tcp.local.
const uint8_t kSamplePacket1[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1 (starts at offset 12)
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp" (offset 20)
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x00, 0x00, 0x01,                   // TTL (4 bytes) is 1 second.
    0x00, 0x08,                               // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o',            // "hello"
    0xc0, 0x0c,                               // Pointer to "_privet._tcp.local"

    // Answer 2 (starts at offset 50)
    0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',  // "_printer"
    0xc0, 0x14,                                    // Pointer to "._tcp.local"
    0x00, 0x0c,                                    // TYPE is PTR.
    0x00, 0x01,                                    // CLASS is IN.
    0x00, 0x01, 0x24, 0x75,  // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
    0x00, 0x08,              // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o',  // "hello"
    0xc0, 0x32};                    // Pointer to "_printer._tcp.local"
78
// Response whose question section cannot be parsed, followed by an answer
// with an impossible RDLENGTH and a truncated second answer.
const uint8_t kCorruptedPacketBadQuestion[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x01,  // One question
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Question is corrupted and cannot be read.
    0x99, 'h', 'e', 'l', 'l', 'o', 0x00, 0x00, 0x00, 0x00, 0x00,

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp"
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x99,              // RDLENGTH is impossible (longer than the packet).
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x0c,

    // Answer 2
    0x08, '_', 'p', 'r',  // Useless trailing data.
};
102
// Response whose first answer has an impossible RDLENGTH, so nothing after it
// can be located; the second answer is truncated.
const uint8_t kCorruptedPacketUnsalvagable[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp"
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x99,              // RDLENGTH is impossible (longer than the packet).
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x0c,

    // Answer 2
    0x08, '_', 'p', 'r',  // Useless trailing data.
};
123
// Response carrying two A records for the same name (privet.local) with
// different addresses.
const uint8_t kCorruptedPacketDoubleRecord[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x06, 'p', 'r', 'i', 'v', 'e', 't',  // "privet"
    0x05, 'l', 'o', 'c', 'a', 'l',       // "local"
    0x00,                                // Root label (end of name).
    0x00, 0x01,                          // TYPE is A.
    0x00, 0x01,                          // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x04,              // RDLENGTH is 4
    0x05, 0x03, 0xc0, 0x0c,  // Address 5.3.192.12

    // Answer 2 -- Same key
    0x06, 'p', 'r', 'i', 'v', 'e', 't',  // "privet"
    0x05, 'l', 'o', 'c', 'a', 'l',       // "local"
    0x00,                                // Root label (end of name).
    0x00, 0x01,                          // TYPE is A.
    0x00, 0x01,                          // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x04,              // RDLENGTH is 4
    0x02, 0x03, 0x04, 0x05,  // Address 2.3.4.5
};
149
// Response whose first answer has malformed RDATA (but a correct RDLENGTH, so
// the parser can skip it) and whose second answer is a valid printer PTR.
const uint8_t kCorruptedPacketSalvagable[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp" (offset 20)
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x08,              // RDLENGTH is 8 bytes.
    0x99, 'h', 'e', 'l', 'l', 'o',  // Bad RDATA format.
    0xc0, 0x0c,

    // Answer 2 (starts at offset 50)
    0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',  // "_printer"
    0xc0, 0x14,                                    // Pointer to "._tcp.local"
    0x00, 0x0c,                                    // TYPE is PTR.
    0x00, 0x01,                                    // CLASS is IN.
    0x00, 0x01, 0x24, 0x75,  // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
    0x00, 0x08,              // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o',  // "hello"
    0xc0, 0x32};                    // Pointer to "_printer._tcp.local"
176
// Like kSamplePacket1 but both PTR targets are "zzzzz.*" and both TTLs are
// 20h47m48s.
const uint8_t kSamplePacket2[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1 (starts at offset 12)
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp" (offset 20)
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x08,              // RDLENGTH is 8 bytes.
    0x05, 'z', 'z', 'z', 'z', 'z',  // "zzzzz"
    0xc0, 0x0c,                     // Pointer to "_privet._tcp.local"

    // Answer 2 (starts at offset 50)
    0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',  // "_printer"
    0xc0, 0x14,                                    // Pointer to "._tcp.local"
    0x00, 0x0c,                                    // TYPE is PTR.
    0x00, 0x01,                                    // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x08,              // RDLENGTH is 8 bytes.
    0x05, 'z', 'z', 'z', 'z', 'z',  // "zzzzz"
    0xc0, 0x32};                    // Pointer to "_printer._tcp.local"
202
// Two PTR answers with short TTLs: privet expires after 1 second, printer
// after 3 seconds.
const uint8_t kSamplePacket3[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1 (starts at offset 12)
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp" (offset 20)
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x00, 0x00, 0x01,                   // TTL (4 bytes) is 1 second.
    0x00, 0x08,                               // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o',            // "hello"
    0xc0, 0x0c,                               // Pointer to "_privet._tcp.local"

    // Answer 2 (starts at offset 50)
    0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',  // "_printer"
    0xc0, 0x14,                                    // Pointer to "._tcp.local"
    0x00, 0x0c,                                    // TYPE is PTR.
    0x00, 0x01,                                    // CLASS is IN.
    0x00, 0x00, 0x00, 0x03,                        // TTL (4 bytes) is 3 seconds.
    0x00, 0x08,                                    // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o',                 // "hello"
    0xc0, 0x32};  // Pointer to "_printer._tcp.local"
234
// Query packet asking for the PTR records of _privet._tcp.local.
const uint8_t kQueryPacketPrivet[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x00, 0x00,  // No flags.
    0x00, 0x01,  // One question.
    0x00, 0x00,  // 0 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Question: _privet._tcp.local, PTR, IN.
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp"
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
};
250
// Query packet asking for the A records of _privet._tcp.local.
const uint8_t kQueryPacketPrivetA[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x00, 0x00,  // No flags.
    0x00, 0x01,  // One question.
    0x00, 0x00,  // 0 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Question: _privet._tcp.local, A, IN.
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp"
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x01,                               // TYPE is A.
    0x00, 0x01,                               // CLASS is IN.
};
266
// Response with no answers whose only record (a _privet._tcp.local PTR) is in
// the additional section.
const uint8_t kSamplePacketAdditionalOnly[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x00,  // 0 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x01,  // 1 additional RR

    // Additional record 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp"
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x08,              // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x0c,
};
284
// Response carrying a single NSEC record for _privet._tcp.local.
const uint8_t kSamplePacketNsec[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x01,  // 1 RR (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp"
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x2f,                               // TYPE is NSEC.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x01, 0x24, 0x74,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x00, 0x06,              // RDLENGTH is 6 bytes.
    0xc0, 0x0c,              // Next domain name: pointer to "_privet._tcp.local".
    0x00, 0x02,              // Type bitmap: window block 0, 2 octets long.
    // Bitmap bit 12 is set, i.e. only the PTR type (12) is present (RFC 4034
    // section 4.1.2). A (type 1) is therefore asserted NOT to exist.
    0x00, 0x08};
302
// Response carrying a single A record for _privet._tcp.local with a 5 s TTL.
const uint8_t kSamplePacketAPrivet[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x01,  // 1 RR (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp"
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x01,                               // TYPE is A.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x00, 0x00, 0x05,                   // TTL (4 bytes) is 5 seconds.
    0x00, 0x04,                               // RDLENGTH is 4 bytes.
    0xc0, 0x0c, 0x00, 0x02,                   // Address 192.12.0.2
};
320
// Response carrying one _privet._tcp.local PTR record with a TTL of zero
// (a goodbye record).
const uint8_t kSamplePacketGoodbye[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RD, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x01,  // 1 RR (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  // "_privet"
    0x04, '_', 't', 'c', 'p',                 // "_tcp"
    0x05, 'l', 'o', 'c', 'a', 'l',            // "local"
    0x00,                                     // Root label (end of name).
    0x00, 0x0c,                               // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x00, 0x00, 0x00,                   // TTL (4 bytes) is zero.
    0x00, 0x08,                               // RDLENGTH is 8 bytes.
    0x05, 'z', 'z', 'z', 'z', 'z', 0xc0, 0x0c,
};
338
// Builds a std::string holding exactly |size| raw bytes copied from |data|.
// Embedded NUL bytes are preserved.
std::string MakeString(const uint8_t* data, unsigned size) {
  const char* const first = reinterpret_cast<const char*>(data);
  return std::string(first, first + size);
}
342
343 class PtrRecordCopyContainer {
344 public:
345 PtrRecordCopyContainer() = default;
346 ~PtrRecordCopyContainer() = default;
347
is_set() const348 bool is_set() const { return set_; }
349
SaveWithDummyArg(int unused,const RecordParsed * value)350 void SaveWithDummyArg(int unused, const RecordParsed* value) {
351 Save(value);
352 }
353
Save(const RecordParsed * value)354 void Save(const RecordParsed* value) {
355 set_ = true;
356 name_ = value->name();
357 ptrdomain_ = value->rdata<PtrRecordRdata>()->ptrdomain();
358 ttl_ = value->ttl();
359 }
360
IsRecordWith(const std::string & name,const std::string & ptrdomain)361 bool IsRecordWith(const std::string& name, const std::string& ptrdomain) {
362 return set_ && name_ == name && ptrdomain_ == ptrdomain;
363 }
364
name()365 const std::string& name() { return name_; }
ptrdomain()366 const std::string& ptrdomain() { return ptrdomain_; }
ttl()367 int ttl() { return ttl_; }
368
369 private:
370 bool set_;
371 std::string name_;
372 std::string ptrdomain_;
373 int ttl_;
374 };
375
376 class MockClock : public base::Clock {
377 public:
378 MockClock() = default;
379 ~MockClock() override = default;
380
381 MOCK_CONST_METHOD0(Now, base::Time());
382
383 private:
384 DISALLOW_COPY_AND_ASSIGN(MockClock);
385 };
386
387 class MockTimer : public base::MockOneShotTimer {
388 public:
MockTimer()389 MockTimer() {}
390 ~MockTimer() override = default;
391
Start(const base::Location & posted_from,base::TimeDelta delay,base::OnceClosure user_task)392 void Start(const base::Location& posted_from,
393 base::TimeDelta delay,
394 base::OnceClosure user_task) override {
395 StartObserver(posted_from, delay);
396 base::MockOneShotTimer::Start(posted_from, delay, std::move(user_task));
397 }
398
399 // StartObserver is invoked when MockTimer::Start() is called.
400 // Does not replace the behavior of MockTimer::Start().
401 MOCK_METHOD2(StartObserver,
402 void(const base::Location& posted_from, base::TimeDelta delay));
403
404 private:
405 DISALLOW_COPY_AND_ASSIGN(MockTimer);
406 };
407
408 } // namespace
409
410 class MDnsTest : public TestWithTaskEnvironment {
411 public:
412 void SetUp() override;
413 void DeleteTransaction();
414 void DeleteBothListeners();
415 void RunFor(base::TimeDelta time_period);
416 void Stop();
417
418 MOCK_METHOD2(MockableRecordCallback, void(MDnsTransaction::Result result,
419 const RecordParsed* record));
420
421 MOCK_METHOD2(MockableRecordCallback2, void(MDnsTransaction::Result result,
422 const RecordParsed* record));
423
424 protected:
425 void ExpectPacket(const uint8_t* packet, unsigned size);
426 void SimulatePacketReceive(const uint8_t* packet, unsigned size);
427
428 std::unique_ptr<MDnsClientImpl> test_client_;
429 IPEndPoint mdns_ipv4_endpoint_;
430 StrictMock<MockMDnsSocketFactory> socket_factory_;
431
432 // Transactions and listeners that can be deleted by class methods for
433 // reentrancy tests.
434 std::unique_ptr<MDnsTransaction> transaction_;
435 std::unique_ptr<MDnsListener> listener1_;
436 std::unique_ptr<MDnsListener> listener2_;
437 };
438
439 class MockListenerDelegate : public MDnsListener::Delegate {
440 public:
441 MOCK_METHOD2(OnRecordUpdate,
442 void(MDnsListener::UpdateType update,
443 const RecordParsed* records));
444 MOCK_METHOD2(OnNsecRecord, void(const std::string&, unsigned));
445 MOCK_METHOD0(OnCachePurged, void());
446 };
447
SetUp()448 void MDnsTest::SetUp() {
449 test_client_.reset(new MDnsClientImpl());
450 ASSERT_THAT(test_client_->StartListening(&socket_factory_), test::IsOk());
451 }
452
SimulatePacketReceive(const uint8_t * packet,unsigned size)453 void MDnsTest::SimulatePacketReceive(const uint8_t* packet, unsigned size) {
454 socket_factory_.SimulateReceive(packet, size);
455 }
456
ExpectPacket(const uint8_t * packet,unsigned size)457 void MDnsTest::ExpectPacket(const uint8_t* packet, unsigned size) {
458 EXPECT_CALL(socket_factory_, OnSendTo(MakeString(packet, size)))
459 .Times(2);
460 }
461
DeleteTransaction()462 void MDnsTest::DeleteTransaction() {
463 transaction_.reset();
464 }
465
DeleteBothListeners()466 void MDnsTest::DeleteBothListeners() {
467 listener1_.reset();
468 listener2_.reset();
469 }
470
RunFor(base::TimeDelta time_period)471 void MDnsTest::RunFor(base::TimeDelta time_period) {
472 base::CancelableCallback<void()> callback(base::Bind(&MDnsTest::Stop,
473 base::Unretained(this)));
474 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
475 FROM_HERE, callback.callback(), time_period);
476
477 base::RunLoop().Run();
478 callback.Cancel();
479 }
480
Stop()481 void MDnsTest::Stop() {
482 base::RunLoop::QuitCurrentWhenIdleDeprecated();
483 }
484
TEST_F(MDnsTest, PassiveListeners) {
  StrictMock<MockListenerDelegate> privet_delegate;
  StrictMock<MockListenerDelegate> printer_delegate;

  PtrRecordCopyContainer privet_record;
  PtrRecordCopyContainer printer_record;

  std::unique_ptr<MDnsListener> privet_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &privet_delegate);
  std::unique_ptr<MDnsListener> printer_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_printer._tcp.local", &printer_delegate);

  ASSERT_TRUE(privet_listener->Start());
  ASSERT_TRUE(printer_listener->Start());

  // The same packet is delivered twice below; each listener must still see
  // exactly one RECORD_ADDED, i.e. records are not double-counted.
  EXPECT_CALL(privet_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(1)
      .WillOnce(
          Invoke(&privet_record, &PtrRecordCopyContainer::SaveWithDummyArg));
  EXPECT_CALL(printer_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(1)
      .WillOnce(
          Invoke(&printer_record, &PtrRecordCopyContainer::SaveWithDummyArg));

  for (int delivery = 0; delivery < 2; ++delivery)
    SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  EXPECT_TRUE(privet_record.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));
  EXPECT_TRUE(printer_record.IsRecordWith("_printer._tcp.local",
                                          "hello._printer._tcp.local"));

  privet_listener.reset();
  printer_listener.reset();
}
527
TEST_F(MDnsTest, PassiveListenersCacheCleanup) {
  StrictMock<MockListenerDelegate> privet_delegate;

  PtrRecordCopyContainer added_record;
  PtrRecordCopyContainer removed_record;

  std::unique_ptr<MDnsListener> privet_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &privet_delegate);
  ASSERT_TRUE(privet_listener->Start());

  EXPECT_CALL(privet_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(1)
      .WillOnce(
          Invoke(&added_record, &PtrRecordCopyContainer::SaveWithDummyArg));

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  EXPECT_TRUE(added_record.IsRecordWith("_privet._tcp.local",
                                        "hello._privet._tcp.local"));

  // After the TTL elapses the record should be reported as removed; the
  // removal callback also stops the run loop so RunFor() can return early.
  EXPECT_CALL(privet_delegate, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
      .Times(1)
      .WillOnce(::testing::DoAll(
          InvokeWithoutArgs(this, &MDnsTest::Stop),
          Invoke(&removed_record, &PtrRecordCopyContainer::SaveWithDummyArg)));

  const int ttl_seconds = added_record.ttl();
  RunFor(base::TimeDelta::FromSeconds(ttl_seconds + 1));

  EXPECT_TRUE(removed_record.IsRecordWith("_privet._tcp.local",
                                          "hello._privet._tcp.local"));
}
562
563 // Ensure that the cleanup task scheduler won't schedule cleanup tasks in the
564 // past if the system clock creeps past the expiration time while in the
565 // cleanup dispatcher.
TEST_F(MDnsTest, CacheCleanupWithShortTTL) {
  // Use a nonzero starting time as a base.
  base::Time start_time = base::Time() + base::TimeDelta::FromSeconds(1);

  MockClock clock;
  // The client takes ownership of the timer; keep a non-owning pointer so the
  // test can set expectations on it and fire it manually.
  auto owned_timer = std::make_unique<MockTimer>();
  MockTimer* timer = owned_timer.get();

  test_client_ =
      std::make_unique<MDnsClientImpl>(&clock, std::move(owned_timer));
  ASSERT_THAT(test_client_->StartListening(&socket_factory_), test::IsOk());

  EXPECT_CALL(*timer, StartObserver(_, _)).Times(1);
  EXPECT_CALL(clock, Now())
      .Times(3)
      .WillRepeatedly(Return(start_time))
      .RetiresOnSaturation();

  // Receive two records with different TTL values.
  // TTL(privet)=1.0s
  // TTL(printer)=3.0s
  StrictMock<MockListenerDelegate> delegate_privet;
  StrictMock<MockListenerDelegate> delegate_printer;

  PtrRecordCopyContainer record_privet;
  PtrRecordCopyContainer record_printer;

  std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
  std::unique_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer);

  ASSERT_TRUE(listener_privet->Start());
  ASSERT_TRUE(listener_printer->Start());

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1));
  EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1));

  SimulatePacketReceive(kSamplePacket3, sizeof(kSamplePacket3));

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
      .Times(Exactly(1));

  // Set the clock to 2.0s, which should clean up the 'privet' record, but not
  // the printer. The mock clock will change Now() mid-execution from 2s to 4s.
  // Note: expectations are FILO-ordered -- t+2 seconds is returned, then t+4.
  EXPECT_CALL(clock, Now())
      .WillOnce(Return(start_time + base::TimeDelta::FromSeconds(4)))
      .RetiresOnSaturation();
  EXPECT_CALL(clock, Now())
      .WillOnce(Return(start_time + base::TimeDelta::FromSeconds(2)))
      .RetiresOnSaturation();

  // The rescheduled cleanup must use a zero delay rather than a negative one.
  EXPECT_CALL(*timer, StartObserver(_, base::TimeDelta()));

  timer->Fire();
}
623
TEST_F(MDnsTest, StopListening) {
  // SetUp() leaves the client listening; StopListening() must turn that off.
  ASSERT_TRUE(test_client_->IsListening());

  test_client_->StopListening();

  EXPECT_FALSE(test_client_->IsListening());
}
630
TEST_F(MDnsTest, StopListening_CacheCleanupScheduled) {
  base::SimpleTestClock test_clock;
  // Use a nonzero starting time as a base.
  test_clock.SetNow(base::Time() + base::TimeDelta::FromSeconds(1));

  auto owned_timer = std::make_unique<base::MockOneShotTimer>();
  // Non-owning handle so the timer can be inspected after the client takes
  // ownership.
  base::OneShotTimer* timer = owned_timer.get();

  test_client_ =
      std::make_unique<MDnsClientImpl>(&test_clock, std::move(owned_timer));
  ASSERT_THAT(test_client_->StartListening(&socket_factory_), test::IsOk());
  ASSERT_TRUE(test_client_->IsListening());

  // The privet record in kSamplePacket3 has TTL=1s, which schedules cleanup.
  SimulatePacketReceive(kSamplePacket3, sizeof(kSamplePacket3));
  ASSERT_TRUE(timer->IsRunning());

  test_client_->StopListening();
  EXPECT_FALSE(test_client_->IsListening());

  // Stopping must also unschedule the pending cleanup.
  EXPECT_FALSE(timer->IsRunning());
}
653
TEST_F(MDnsTest, MalformedPacket) {
  StrictMock<MockListenerDelegate> printer_delegate;

  PtrRecordCopyContainer printer_record;

  std::unique_ptr<MDnsListener> printer_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_printer._tcp.local", &printer_delegate);
  ASSERT_TRUE(printer_listener->Start());

  // Across the three packets below exactly one printer record is reported.
  EXPECT_CALL(printer_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(1)
      .WillOnce(
          Invoke(&printer_record, &PtrRecordCopyContainer::SaveWithDummyArg));

  // An unsalvageable packet must be handled gracefully.
  SimulatePacketReceive(kCorruptedPacketUnsalvagable,
                        sizeof(kCorruptedPacketUnsalvagable));

  // Regression test: a packet whose question cannot be read.
  SimulatePacketReceive(kCorruptedPacketBadQuestion,
                        sizeof(kCorruptedPacketBadQuestion));

  // A salvageable packet must still yield its usable records.
  SimulatePacketReceive(kCorruptedPacketSalvagable,
                        sizeof(kCorruptedPacketSalvagable));

  EXPECT_TRUE(printer_record.IsRecordWith("_printer._tcp.local",
                                          "hello._printer._tcp.local"));
}
685
TEST_F(MDnsTest, TransactionWithEmptyCache) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  auto privet_transaction = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_privet._tcp.local",
      MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
          MDnsTransaction::SINGLE_RESULT,
      base::BindRepeating(&MDnsTest::MockableRecordCallback,
                          base::Unretained(this)));
  ASSERT_TRUE(privet_transaction->Start());

  // The freshly created client has nothing cached, so the single result has
  // to come from the simulated network response.
  PtrRecordCopyContainer privet_record;
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .Times(1)
      .WillOnce(
          Invoke(&privet_record, &PtrRecordCopyContainer::SaveWithDummyArg));

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  EXPECT_TRUE(privet_record.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));
}
711
TEST_F(MDnsTest, TransactionCacheOnlyNoResult) {
  // Cache-only lookup against an empty cache: exactly one RESULT_NO_RESULTS
  // callback is expected, and no run loop is spun after Start().
  auto privet_transaction = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_privet._tcp.local",
      MDnsTransaction::QUERY_CACHE | MDnsTransaction::SINGLE_RESULT,
      base::BindRepeating(&MDnsTest::MockableRecordCallback,
                          base::Unretained(this)));

  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, _))
      .Times(1);

  ASSERT_TRUE(privet_transaction->Start());
}
726
TEST_F(MDnsTest, TransactionWithCache) {
  // An unrelated listener, started only to force the client to listen.
  StrictMock<MockListenerDelegate> irrelevant_delegate;
  std::unique_ptr<MDnsListener> irrelevant_listener =
      test_client_->CreateListener(dns_protocol::kTypeA,
                                   "codereview.chromium.local",
                                   &irrelevant_delegate);
  ASSERT_TRUE(irrelevant_listener->Start());

  // Populate the cache before the transaction exists.
  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  PtrRecordCopyContainer privet_record;
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .WillOnce(
          Invoke(&privet_record, &PtrRecordCopyContainer::SaveWithDummyArg));

  auto privet_transaction = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_privet._tcp.local",
      MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
          MDnsTransaction::SINGLE_RESULT,
      base::BindRepeating(&MDnsTest::MockableRecordCallback,
                          base::Unretained(this)));

  // The cached record is checked right after Start() without running the
  // loop, so it must have been delivered during Start().
  ASSERT_TRUE(privet_transaction->Start());

  EXPECT_TRUE(privet_record.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));
}
759
TEST_F(MDnsTest, AdditionalRecords) {
  StrictMock<MockListenerDelegate> privet_delegate;

  PtrRecordCopyContainer privet_record;

  std::unique_ptr<MDnsListener> privet_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &privet_delegate);
  ASSERT_TRUE(privet_listener->Start());

  // The packet's only record sits in the additional section; it must still
  // be reported to the listener.
  EXPECT_CALL(privet_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(1)
      .WillOnce(
          Invoke(&privet_record, &PtrRecordCopyContainer::SaveWithDummyArg));

  SimulatePacketReceive(kSamplePacketAdditionalOnly,
                        sizeof(kSamplePacketAdditionalOnly));

  EXPECT_TRUE(privet_record.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));
}
782
TEST_F(MDnsTest, TransactionTimeout) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  std::unique_ptr<MDnsTransaction> transaction_privet =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  ASSERT_TRUE(transaction_privet->Start());

  // No response is ever delivered, so within the 4 s RunFor() window the
  // transaction must report RESULT_NO_RESULTS with a null record.
  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS,
                                     nullptr))
      .Times(Exactly(1))
      .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop));

  RunFor(base::TimeDelta::FromSeconds(4));
}
803
TEST_F(MDnsTest, TransactionMultipleRecords) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  // No SINGLE_RESULT flag, so the transaction may report several records.
  std::unique_ptr<MDnsTransaction> transaction_privet =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  ASSERT_TRUE(transaction_privet->Start());

  PtrRecordCopyContainer record_privet;
  PtrRecordCopyContainer record_privet2;

  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .Times(Exactly(2))
      .WillOnce(Invoke(&record_privet,
                       &PtrRecordCopyContainer::SaveWithDummyArg))
      .WillOnce(Invoke(&record_privet2,
                       &PtrRecordCopyContainer::SaveWithDummyArg));

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
  SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2));

  EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));

  EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local",
                                          "zzzzz._privet._tcp.local"));

  // Eventually the transaction finishes with RESULT_DONE and no record.
  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_DONE, nullptr))
      .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop));

  RunFor(base::TimeDelta::FromSeconds(4));
}
840
TEST_F(MDnsTest, TransactionReentrantDelete) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  transaction_ = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_privet._tcp.local",
      MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
          MDnsTransaction::SINGLE_RESULT,
      base::BindRepeating(&MDnsTest::MockableRecordCallback,
                          base::Unretained(this)));

  ASSERT_TRUE(transaction_->Start());

  // The timeout callback deletes the very transaction that is invoking it.
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS,
                                            nullptr))
      .Times(Exactly(1))
      .WillOnce(::testing::DoAll(
          InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction),
          InvokeWithoutArgs(this, &MDnsTest::Stop)));

  RunFor(base::TimeDelta::FromSeconds(4));

  EXPECT_EQ(nullptr, transaction_.get());
}
863
TEST_F(MDnsTest, TransactionReentrantDeleteFromCache) {
  // An unrelated listener, started only to force the client to listen.
  StrictMock<MockListenerDelegate> delegate_irrelevant;
  std::unique_ptr<MDnsListener> listener_irrelevant =
      test_client_->CreateListener(dns_protocol::kTypeA,
                                   "codereview.chromium.local",
                                   &delegate_irrelevant);
  ASSERT_TRUE(listener_irrelevant->Start());

  // Populate the cache so the transaction's first result comes from it.
  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  transaction_ = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_privet._tcp.local",
      MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE,
      base::BindRepeating(&MDnsTest::MockableRecordCallback,
                          base::Unretained(this)));

  // The cache-hit callback deletes the transaction while Start() is running.
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .Times(Exactly(1))
      .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction));

  ASSERT_TRUE(transaction_->Start());

  EXPECT_EQ(nullptr, transaction_.get());
}
888
// Starting a second, cache-only transaction from inside the first
// transaction's result callback must work. The second transaction is expected
// to be answered (RESULT_RECORD) from data carried by the same packet.
TEST_F(MDnsTest, TransactionReentrantCacheLookupStart) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  // Network + cache lookup for the _privet service.
  auto network_transaction = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_privet._tcp.local",
      MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
          MDnsTransaction::SINGLE_RESULT,
      base::BindRepeating(&MDnsTest::MockableRecordCallback,
                          base::Unretained(this)));

  // Cache-only lookup for the _printer service; started re-entrantly below.
  auto cache_transaction = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_printer._tcp.local",
      MDnsTransaction::QUERY_CACHE | MDnsTransaction::SINGLE_RESULT,
      base::BindRepeating(&MDnsTest::MockableRecordCallback2,
                          base::Unretained(this)));

  EXPECT_CALL(*this,
              MockableRecordCallback2(MDnsTransaction::RESULT_RECORD, _))
      .Times(1);

  // When the first transaction gets its record, kick off the second one.
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .Times(1)
      .WillOnce(IgnoreResult(InvokeWithoutArgs(cache_transaction.get(),
                                               &MDnsTransaction::Start)));

  ASSERT_TRUE(network_transaction->Start());

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
}
921
// A goodbye packet received without any earlier announcement in this test
// must not trigger any listener notification: the StrictMock delegate has no
// expectations, so any call on it fails the test.
TEST_F(MDnsTest, GoodbyePacketNotification) {
  StrictMock<MockListenerDelegate> delegate;

  auto listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &delegate);
  ASSERT_TRUE(listener->Start());

  SimulatePacketReceive(kSamplePacketGoodbye, sizeof(kSamplePacketGoodbye));

  const base::TimeDelta kSettleTime = base::TimeDelta::FromSeconds(2);
  RunFor(kSettleTime);
}
933
// A goodbye packet for a record that was previously added should cause that
// record to be reported as removed.
TEST_F(MDnsTest, GoodbyePacketRemoval) {
  StrictMock<MockListenerDelegate> delegate;

  auto listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &delegate);
  ASSERT_TRUE(listener->Start());

  // First packet announces the record.
  EXPECT_CALL(delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(1);
  SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2));

  // The removal is only expected once the loop runs below; an immediate
  // removal would hit the StrictMock before this expectation exists.
  SimulatePacketReceive(kSamplePacketGoodbye, sizeof(kSamplePacketGoodbye));
  EXPECT_CALL(delegate, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
      .Times(1);

  const base::TimeDelta kSettleTime = base::TimeDelta::FromSeconds(2);
  RunFor(kSettleTime);
}
953
954 // In order to reliably test reentrant listener deletes, we create two listeners
955 // and have each of them delete both, so we're guaranteed to try and deliver a
956 // callback to at least one deleted listener.
957
TEST_F(MDnsTest,ListenerReentrantDelete)958 TEST_F(MDnsTest, ListenerReentrantDelete) {
959 StrictMock<MockListenerDelegate> delegate_privet;
960
961 listener1_ = test_client_->CreateListener(
962 dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
963
964 listener2_ = test_client_->CreateListener(
965 dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
966
967 ASSERT_TRUE(listener1_->Start());
968
969 ASSERT_TRUE(listener2_->Start());
970
971 EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
972 .Times(Exactly(1))
973 .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteBothListeners));
974
975 SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
976
977 EXPECT_EQ(NULL, listener1_.get());
978 EXPECT_EQ(NULL, listener2_.get());
979 }
980
// gmock action: copies the A-record address from the mock call's second
// argument (a const RecordParsed*) into |ip_container| (an IPAddress*).
ACTION_P(SaveIPAddress, ip_container) {
  // Compile-time guards that the action is attached to the expected types.
  ::testing::StaticAssertTypeEq<const RecordParsed*, arg1_type>();
  ::testing::StaticAssertTypeEq<IPAddress*, ip_container_type>();

  const auto* a_rdata = arg1->template rdata<ARecordRdata>();
  *ip_container = a_rdata->address();
}
987
// A packet that (judging by its name) carries two disagreeing A records for
// the same name must yield exactly one RECORD_ADDED, reporting 2.3.4.5.
TEST_F(MDnsTest, DoubleRecordDisagreeing) {
  IPAddress saved_address;
  StrictMock<MockListenerDelegate> delegate;

  std::unique_ptr<MDnsListener> listener = test_client_->CreateListener(
      dns_protocol::kTypeA, "privet.local", &delegate);
  ASSERT_TRUE(listener->Start());

  EXPECT_CALL(delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(1)
      .WillOnce(SaveIPAddress(&saved_address));

  SimulatePacketReceive(kCorruptedPacketDoubleRecord,
                        sizeof(kCorruptedPacketDoubleRecord));

  EXPECT_EQ("2.3.4.5", saved_address.ToString());
}
1006
// An NSEC record must notify listeners for RR types it marks as nonexistent
// (A here) and must NOT notify listeners for types it marks as existing (PTR).
TEST_F(MDnsTest, NsecWithListener) {
  StrictMock<MockListenerDelegate> delegate_privet;
  std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
      dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet);

  // Test to make sure nsec callback is NOT called for PTR
  // (which is marked as existing). The StrictMock delegate has no
  // expectations, so any call fails the test. The listener must be started;
  // an unstarted listener presumably gets no notifications at all, which
  // would make this check vacuous.
  StrictMock<MockListenerDelegate> delegate_privet2;
  std::unique_ptr<MDnsListener> listener_privet2 = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet2);

  ASSERT_TRUE(listener_privet->Start());
  ASSERT_TRUE(listener_privet2->Start());

  EXPECT_CALL(delegate_privet,
              OnNsecRecord("_privet._tcp.local", dns_protocol::kTypeA));

  SimulatePacketReceive(kSamplePacketNsec, sizeof(kSamplePacketNsec));
}
1026
// An NSEC record that arrives from the network while an A transaction is
// pending completes that transaction with RESULT_NSEC and no record.
TEST_F(MDnsTest, NsecWithTransactionFromNetwork) {
  std::unique_ptr<MDnsTransaction> transaction_privet =
      test_client_->CreateTransaction(
          dns_protocol::kTypeA, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  // Starting the transaction sends the query twice (presumably once per
  // socket in the fixture).
  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);

  ASSERT_TRUE(transaction_privet->Start());

  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_NSEC, nullptr));

  SimulatePacketReceive(kSamplePacketNsec, sizeof(kSamplePacketNsec));
}
1046
// A cached NSEC record answers an A transaction (RESULT_NSEC) but is not a
// valid answer for a PTR transaction, which must go to the network instead.
TEST_F(MDnsTest, NsecWithTransactionFromCache) {
  // Force mDNS to listen. Assert on Start() like every other test does, so a
  // failure here is reported directly instead of as a confusing later error.
  StrictMock<MockListenerDelegate> delegate_irrelevant;
  std::unique_ptr<MDnsListener> listener_irrelevant =
      test_client_->CreateListener(dns_protocol::kTypePTR, "_privet._tcp.local",
                                   &delegate_irrelevant);
  ASSERT_TRUE(listener_irrelevant->Start());

  SimulatePacketReceive(kSamplePacketNsec, sizeof(kSamplePacketNsec));

  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_NSEC, nullptr));

  std::unique_ptr<MDnsTransaction> transaction_privet_a =
      test_client_->CreateTransaction(
          dns_protocol::kTypeA, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  ASSERT_TRUE(transaction_privet_a->Start());

  // Test that a PTR transaction does NOT consider the same NSEC record to be a
  // valid answer to the query; it is expected to send queries instead.

  std::unique_ptr<MDnsTransaction> transaction_privet_ptr =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);

  ASSERT_TRUE(transaction_privet_ptr->Start());
}
1086
// An NSEC record denying type A must remove a previously cached A record for
// the same name, and the removal must refer to the same record that was added.
TEST_F(MDnsTest, NsecConflictRemoval) {
  StrictMock<MockListenerDelegate> delegate_privet;
  std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
      dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet);

  ASSERT_TRUE(listener_privet->Start());

  // Initialized so that the final comparison never reads indeterminate values
  // if one of the expected callbacks fails to fire.
  const RecordParsed* record1 = nullptr;
  const RecordParsed* record2 = nullptr;

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .WillOnce(SaveArg<1>(&record1));

  SimulatePacketReceive(kSamplePacketAPrivet, sizeof(kSamplePacketAPrivet));

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
      .WillOnce(SaveArg<1>(&record2));

  EXPECT_CALL(delegate_privet,
              OnNsecRecord("_privet._tcp.local", dns_protocol::kTypeA));

  SimulatePacketReceive(kSamplePacketNsec, sizeof(kSamplePacketNsec));

  // Only the pointer values are compared; the record itself may already be
  // gone from the cache.
  EXPECT_EQ(record1, record2);
}
1114
1115
// With active refresh enabled, a listener re-queries its record before it
// expires, and the record is eventually reported as removed.
TEST_F(MDnsTest, RefreshQuery) {
  StrictMock<MockListenerDelegate> delegate;
  std::unique_ptr<MDnsListener> listener = test_client_->CreateListener(
      dns_protocol::kTypeA, "_privet._tcp.local", &delegate);

  listener->SetActiveRefresh(true);
  ASSERT_TRUE(listener->Start());

  EXPECT_CALL(delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _));
  SimulatePacketReceive(kSamplePacketAPrivet, sizeof(kSamplePacketAPrivet));

  // Two scheduled refresh queries, each sent once for IPv4 and once for IPv6:
  // four sends of the same query packet in total.
  const std::string expected_query =
      MakeString(kQueryPacketPrivetA, sizeof(kQueryPacketPrivetA));
  EXPECT_CALL(socket_factory_, OnSendTo(expected_query)).Times(4);

  // No answer to the refreshes arrives, so the record is expected to go away.
  EXPECT_CALL(delegate, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _));

  RunFor(base::TimeDelta::FromSeconds(6));
}
1139
1140 // MDnsSocketFactory implementation that creates a single socket that will
1141 // always fail on RecvFrom. Passing this to MdnsClient is expected to result in
1142 // the client failing to start listening.
1143 class FailingSocketFactory : public MDnsSocketFactory {
CreateSockets(std::vector<std::unique_ptr<DatagramServerSocket>> * sockets)1144 void CreateSockets(
1145 std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override {
1146 auto socket =
1147 std::make_unique<MockMDnsDatagramServerSocket>(ADDRESS_FAMILY_IPV4);
1148 EXPECT_CALL(*socket, RecvFromInternal(_, _, _, _))
1149 .WillRepeatedly(Return(ERR_FAILED));
1150 sockets->push_back(std::move(socket));
1151 }
1152 };
1153
// StartListening() must propagate the socket read failure as ERR_FAILED.
TEST_F(MDnsTest, StartListeningFailure) {
  // Replace the fixture's client with a fresh one for this test.
  test_client_ = std::make_unique<MDnsClientImpl>();
  FailingSocketFactory failing_factory;

  const auto result = test_client_->StartListening(&failing_factory);
  EXPECT_THAT(result, test::IsError(ERR_FAILED));
}
1161
1162 // Test that the cache is cleared when it gets filled to unreasonable sizes.
TEST_F(MDnsTest,ClearOverfilledCache)1163 TEST_F(MDnsTest, ClearOverfilledCache) {
1164 test_client_->core()->cache_for_testing()->set_entry_limit_for_testing(1);
1165
1166 StrictMock<MockListenerDelegate> delegate_privet;
1167 StrictMock<MockListenerDelegate> delegate_printer;
1168
1169 PtrRecordCopyContainer record_privet;
1170 PtrRecordCopyContainer record_printer;
1171
1172 std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
1173 dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
1174 std::unique_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
1175 dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer);
1176
1177 ASSERT_TRUE(listener_privet->Start());
1178 ASSERT_TRUE(listener_printer->Start());
1179
1180 bool privet_added = false;
1181 EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
1182 .Times(AtMost(1))
1183 .WillOnce(Assign(&privet_added, true));
1184 EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
1185 .WillRepeatedly(Assign(&privet_added, false));
1186
1187 bool printer_added = false;
1188 EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
1189 .Times(AtMost(1))
1190 .WillOnce(Assign(&printer_added, true));
1191 EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
1192 .WillRepeatedly(Assign(&printer_added, false));
1193
1194 // Fill past capacity and expect everything to eventually be removed.
1195 SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
1196 base::RunLoop().RunUntilIdle();
1197 EXPECT_FALSE(privet_added);
1198 EXPECT_FALSE(printer_added);
1199 }
1200
1201 // Note: These tests assume that the ipv4 socket will always be created first.
1202 // This is a simplifying assumption based on the way the code works now.
1203 class SimpleMockSocketFactory : public MDnsSocketFactory {
1204 public:
CreateSockets(std::vector<std::unique_ptr<DatagramServerSocket>> * sockets)1205 void CreateSockets(
1206 std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override {
1207 sockets->clear();
1208 sockets->swap(sockets_);
1209 }
1210
PushSocket(std::unique_ptr<DatagramServerSocket> socket)1211 void PushSocket(std::unique_ptr<DatagramServerSocket> socket) {
1212 sockets_.push_back(std::move(socket));
1213 }
1214
1215 private:
1216 std::vector<std::unique_ptr<DatagramServerSocket>> sockets_;
1217 };
1218
1219 class MockMDnsConnectionDelegate : public MDnsConnection::Delegate {
1220 public:
HandlePacket(DnsResponse * response,int size)1221 void HandlePacket(DnsResponse* response, int size) override {
1222 HandlePacketInternal(std::string(response->io_buffer()->data(), size));
1223 }
1224
1225 MOCK_METHOD1(HandlePacketInternal, void(std::string packet));
1226
1227 MOCK_METHOD1(OnConnectionError, void(int error));
1228 };
1229
1230 class MDnsConnectionTest : public TestWithTaskEnvironment {
1231 public:
MDnsConnectionTest()1232 MDnsConnectionTest() : connection_(&delegate_) {
1233 }
1234
1235 protected:
1236 // Follow successful connection initialization.
SetUp()1237 void SetUp() override {
1238 socket_ipv4_ = new MockMDnsDatagramServerSocket(ADDRESS_FAMILY_IPV4);
1239 socket_ipv6_ = new MockMDnsDatagramServerSocket(ADDRESS_FAMILY_IPV6);
1240 factory_.PushSocket(base::WrapUnique(socket_ipv6_));
1241 factory_.PushSocket(base::WrapUnique(socket_ipv4_));
1242 sample_packet_ = MakeString(kSamplePacket1, sizeof(kSamplePacket1));
1243 sample_buffer_ = base::MakeRefCounted<StringIOBuffer>(sample_packet_);
1244 }
1245
InitConnection()1246 int InitConnection() { return connection_.Init(&factory_); }
1247
1248 StrictMock<MockMDnsConnectionDelegate> delegate_;
1249
1250 MockMDnsDatagramServerSocket* socket_ipv4_;
1251 MockMDnsDatagramServerSocket* socket_ipv6_;
1252 SimpleMockSocketFactory factory_;
1253 MDnsConnection connection_;
1254 TestCompletionCallback callback_;
1255 std::string sample_packet_;
1256 scoped_refptr<IOBuffer> sample_buffer_;
1257 };
1258
// A packet that is available as soon as reading starts reaches the delegate
// during InitConnection() itself (no loop is run in this test).
TEST_F(MDnsConnectionTest, ReceiveSynchronous) {
  socket_ipv6_->SetResponsePacket(sample_packet_);

  // The IPv4 socket never produces data.
  EXPECT_CALL(*socket_ipv4_, RecvFromInternal(_, _, _, _))
      .WillOnce(Return(ERR_IO_PENDING));

  // The IPv6 socket serves the sample packet on the first read (via
  // HandleRecvNow), then the next read stays pending.
  EXPECT_CALL(*socket_ipv6_, RecvFromInternal(_, _, _, _))
      .WillOnce(
          Invoke(socket_ipv6_, &MockMDnsDatagramServerSocket::HandleRecvNow))
      .WillOnce(Return(ERR_IO_PENDING));

  EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet_));

  const int init_result = InitConnection();
  EXPECT_THAT(init_result, test::IsOk());
}
1271
// A packet that is served later (via HandleRecvLater) is delivered to the
// delegate only once pending tasks run.
TEST_F(MDnsConnectionTest, ReceiveAsynchronous) {
  socket_ipv6_->SetResponsePacket(sample_packet_);

  EXPECT_CALL(*socket_ipv4_, RecvFromInternal(_, _, _, _))
      .WillOnce(Return(ERR_IO_PENDING));

  // First IPv6 read completes later; the follow-up read stays pending.
  EXPECT_CALL(*socket_ipv6_, RecvFromInternal(_, _, _, _))
      .Times(2)
      .WillOnce(
          Invoke(socket_ipv6_, &MockMDnsDatagramServerSocket::HandleRecvLater))
      .WillOnce(Return(ERR_IO_PENDING));

  ASSERT_THAT(InitConnection(), test::IsOk());

  // Set only after Init: the StrictMock would reject an early delivery.
  EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet_));

  base::RunLoop run_loop;
  run_loop.RunUntilIdle();
}
1289
TEST_F(MDnsConnectionTest,Error)1290 TEST_F(MDnsConnectionTest, Error) {
1291 CompletionRepeatingCallback callback;
1292
1293 EXPECT_CALL(*socket_ipv4_, RecvFromInternal(_, _, _, _))
1294 .WillOnce(Return(ERR_IO_PENDING));
1295 EXPECT_CALL(*socket_ipv6_, RecvFromInternal(_, _, _, _))
1296 .WillOnce(DoAll(SaveArg<3>(&callback), Return(ERR_IO_PENDING)));
1297
1298 ASSERT_THAT(InitConnection(), test::IsOk());
1299
1300 EXPECT_CALL(delegate_, OnConnectionError(ERR_SOCKET_NOT_CONNECTED));
1301 callback.Run(ERR_SOCKET_NOT_CONNECTED);
1302 base::RunLoop().RunUntilIdle();
1303 }
1304
1305 class MDnsConnectionSendTest : public MDnsConnectionTest {
1306 protected:
SetUp()1307 void SetUp() override {
1308 MDnsConnectionTest::SetUp();
1309 EXPECT_CALL(*socket_ipv4_, RecvFromInternal(_, _, _, _))
1310 .WillOnce(Return(ERR_IO_PENDING));
1311 EXPECT_CALL(*socket_ipv6_, RecvFromInternal(_, _, _, _))
1312 .WillOnce(Return(ERR_IO_PENDING));
1313 EXPECT_THAT(InitConnection(), test::IsOk());
1314 }
1315 };
1316
// One Send() goes out on both sockets, to the IPv4 and IPv6 mDNS multicast
// addresses respectively.
TEST_F(MDnsConnectionSendTest, Send) {
  EXPECT_CALL(*socket_ipv4_,
              SendToInternal(sample_packet_, "224.0.0.251:5353", _));
  EXPECT_CALL(*socket_ipv6_,
              SendToInternal(sample_packet_, "[ff02::fb]:5353", _));

  const auto packet_size = sample_packet_.size();
  connection_.Send(sample_buffer_, packet_size);
}
1325
// A synchronous send failure on one socket is reported to the delegate, but
// only once pending tasks run.
TEST_F(MDnsConnectionSendTest, SendError) {
  EXPECT_CALL(*socket_ipv4_,
              SendToInternal(sample_packet_, "224.0.0.251:5353", _));

  // The IPv6 send fails immediately.
  EXPECT_CALL(*socket_ipv6_,
              SendToInternal(sample_packet_, "[ff02::fb]:5353", _))
      .WillOnce(Return(ERR_SOCKET_NOT_CONNECTED));

  connection_.Send(sample_buffer_, sample_packet_.size());

  // Set after Send(): the StrictMock delegate would reject an immediate call.
  EXPECT_CALL(delegate_, OnConnectionError(ERR_SOCKET_NOT_CONNECTED));

  base::RunLoop run_loop;
  run_loop.RunUntilIdle();
}
1337
// While a send is pending on a socket, later sends on that socket are queued
// and issued only after the pending one completes; the other socket is not
// held back.
TEST_F(MDnsConnectionSendTest, SendQueued) {
  // IPv4 sends complete synchronously, so both packets go out right away.
  EXPECT_CALL(*socket_ipv4_,
              SendToInternal(sample_packet_, "224.0.0.251:5353", _))
      .Times(2)
      .WillRepeatedly(Return(OK));

  // The first IPv6 send stays pending; keep its completion callback (3rd arg).
  // Only this first call may reach the socket for now.
  CompletionRepeatingCallback pending_ipv6_send;
  EXPECT_CALL(*socket_ipv6_,
              SendToInternal(sample_packet_, "[ff02::fb]:5353", _))
      .WillOnce(::testing::DoAll(SaveArg<2>(&pending_ipv6_send),
                                 Return(ERR_IO_PENDING)));

  const auto packet_size = sample_packet_.size();
  connection_.Send(sample_buffer_, packet_size);
  connection_.Send(sample_buffer_, packet_size);

  // Completing the pending send flushes the queued IPv6 packet; nothing more
  // is sent on IPv4.
  EXPECT_CALL(*socket_ipv4_,
              SendToInternal(sample_packet_, "224.0.0.251:5353", _))
      .Times(0);
  EXPECT_CALL(*socket_ipv6_,
              SendToInternal(sample_packet_, "[ff02::fb]:5353", _))
      .WillOnce(Return(OK));
  pending_ipv6_send.Run(OK);
}
1364
// Smoke test: creating and binding a real IPv4 mDNS socket must succeed.
TEST(MDnsSocketTest, CreateSocket) {
  // Verifies that socket creation hasn't been broken.
  auto socket = CreateAndBindMDnsSocket(AddressFamily::ADDRESS_FAMILY_IPV4, 1,
                                        net::NetLog::Get());
  // ASSERT (fatal) rather than EXPECT: on failure the test must stop here
  // instead of dereferencing a null socket below.
  ASSERT_TRUE(socket);
  socket->Close();
}
1372
1373 } // namespace net
1374