1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include <memory>
#include <utility>
#include <vector>

#include "base/bind.h"
#include "base/cancelable_callback.h"
#include "base/location.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/memory/ref_counted.h"
#include "base/run_loop.h"
#include "base/single_thread_task_runner.h"
#include "base/test/simple_test_clock.h"
#include "base/threading/thread_task_runner_handle.h"
#include "base/time/clock.h"
#include "base/time/default_clock.h"
#include "base/timer/mock_timer.h"
#include "base/timer/timer.h"
#include "net/base/address_family.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/ip_address.h"
#include "net/base/rand_callback.h"
#include "net/base/test_completion_callback.h"
#include "net/dns/mdns_client_impl.h"
#include "net/dns/mock_mdns_socket_factory.h"
#include "net/dns/record_rdata.h"
#include "net/log/net_log.h"
#include "net/socket/udp_client_socket.h"
#include "net/test/gtest_util.h"
#include "net/test/test_with_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
36 
37 using ::testing::_;
38 using ::testing::Assign;
39 using ::testing::AtMost;
40 using ::testing::Exactly;
41 using ::testing::IgnoreResult;
42 using ::testing::Invoke;
43 using ::testing::InvokeWithoutArgs;
44 using ::testing::NiceMock;
45 using ::testing::Return;
46 using ::testing::SaveArg;
47 using ::testing::StrictMock;
48 
49 namespace net {
50 
51 namespace {
52 
// Well-formed mDNS response carrying two PTR answers:
//   _privet._tcp.local  PTR hello._privet._tcp.local   (TTL 1 second)
//   _printer._tcp.local PTR hello._printer._tcp.local  (TTL 74869 seconds)
// Name-compression pointers used below (offsets from the start of the packet;
// the 12-byte header occupies offsets 0-11):
//   0xc0 0x0c -> offset 12, answer 1's owner name "_privet._tcp.local"
//   0xc0 0x14 -> offset 20, the "_tcp.local" suffix of that name
//   0xc0 0x32 -> offset 50, answer 2's owner name "_printer._tcp.local"
const uint8_t kSamplePacket1[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x00,                                 // TTL (4 bytes) is 1 second;
    0x00, 0x01, 0x00, 0x08,                     // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x0c,

    // Answer 2
    0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', 0xc0,
    0x14,        // Pointer to "._tcp.local"
    0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,  // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
    0x24, 0x75, 0x00, 0x08,  // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x32};
78 
// Corrupted response that declares one question whose first label claims a
// length of 0x99 bytes, far more than the data that follows, so the question
// section cannot be parsed. The answer section that follows is also corrupt
// (impossible RDLENGTH, truncated second answer).
const uint8_t kCorruptedPacketBadQuestion[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x01,  // One question
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Question is corrupted and cannot be read.
    0x99, 'h', 'e', 'l', 'l', 'o', 0x00, 0x00, 0x00, 0x00, 0x00,

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x99,  // RDLENGTH is impossible
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x0c,

    // Answer 2
    0x08, '_', 'p', 'r',  // Useless trailing data.
};
102 
// Corrupted response with no recoverable records. Answer 1 declares an
// RDLENGTH of 0x99 bytes, which runs past the end of the packet, and answer 2
// is cut off after a few bytes of its owner name.
const uint8_t kCorruptedPacketUnsalvagable[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x99,  // RDLENGTH is impossible
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x0c,

    // Answer 2
    0x08, '_', 'p', 'r',  // Useless trailing data.
};
123 
// Response with two A records that share the same owner name, type and class
// ("privet.local", A, IN) but carry different addresses: 5.3.192.12 in
// answer 1 and 2.3.4.5 in answer 2.
const uint8_t kCorruptedPacketDoubleRecord[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x06, 'p', 'r', 'i', 'v', 'e', 't', 0x05, 'l', 'o', 'c', 'a', 'l', 0x00,
    0x00, 0x01,  // TYPE is A.
    0x00, 0x01,  // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x04,  // RDLENGTH is 4
    0x05, 0x03, 0xc0, 0x0c,

    // Answer 2 -- Same key
    0x06, 'p', 'r', 'i', 'v', 'e', 't', 0x05, 'l', 'o', 'c', 'a', 'l', 0x00,
    0x00, 0x01,  // TYPE is A.
    0x00, 0x01,  // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x04,  // RDLENGTH is 4
    0x02, 0x03, 0x04, 0x05,
};
149 
// Partly corrupted response. Answer 1's RDATA is malformed: its first label
// claims 0x99 bytes although RDLENGTH is only 8. Answer 2
// (_printer._tcp.local PTR hello._printer._tcp.local) is intact and should
// still be usable. Answer 1 has the same length as in kSamplePacket1, so the
// compression pointers 0xc0 0x14 and 0xc0 0x32 resolve to the same names as
// they do there.
const uint8_t kCorruptedPacketSalvagable[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x08,         // RDLENGTH is 8 bytes.
    0x99, 'h', 'e', 'l', 'l', 'o',  // Bad RDATA format.
    0xc0, 0x0c,

    // Answer 2
    0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', 0xc0,
    0x14,        // Pointer to "._tcp.local"
    0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,  // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
    0x24, 0x75, 0x00, 0x08,  // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x32};
176 
// Same owner names as kSamplePacket1, but both PTR answers point at
// "zzzzz.<owner>" instead of "hello.<owner>", and both have a TTL of 74868
// seconds.
const uint8_t kSamplePacket2[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x08,  // RDLENGTH is 8 bytes.
    0x05, 'z', 'z', 'z', 'z', 'z', 0xc0, 0x0c,

    // Answer 2
    0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', 0xc0,
    0x14,        // Pointer to "._tcp.local"
    0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,  // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x08,  // RDLENGTH is 8 bytes.
    0x05, 'z', 'z', 'z', 'z', 'z', 0xc0, 0x32};
202 
// Same records as kSamplePacket1 but with short TTLs: the privet PTR expires
// after 1 second and the printer PTR after 3 seconds. Each 4-byte TTL is
// split across two source lines below.
const uint8_t kSamplePacket3[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x02,  // 2 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',  //
    0x04, '_', 't', 'c', 'p',                 //
    0x05, 'l', 'o', 'c', 'a', 'l',            //
    0x00, 0x00, 0x0c,                         // TYPE is PTR.
    0x00, 0x01,                               // CLASS is IN.
    0x00, 0x00,                               // TTL (4 bytes) is 1 second;
    0x00, 0x01,                               //
    0x00, 0x08,                               // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o',            //
    0xc0, 0x0c,                               //

    // Answer 2
    0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',  //
    0xc0, 0x14,                                    // Pointer to "._tcp.local"
    0x00, 0x0c,                                    // TYPE is PTR.
    0x00, 0x01,                                    // CLASS is IN.
    0x00, 0x00,                     // TTL (4 bytes) is 3 seconds.
    0x00, 0x03,                     //
    0x00, 0x08,                     // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o',  //
    0xc0, 0x32};
234 
// Query packet asking for the PTR records of "_privet._tcp.local". Tests pass
// it to ExpectPacket() as the bytes the client must send.
const uint8_t kQueryPacketPrivet[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x00, 0x00,  // No flags.
    0x00, 0x01,  // One question.
    0x00, 0x00,  // 0 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Question
    // This part is echoed back from the respective query.
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,                                 // CLASS is IN.
};
250 
// Query packet asking for the A records of "_privet._tcp.local". It is
// identical to kQueryPacketPrivet except for the question TYPE.
const uint8_t kQueryPacketPrivetA[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x00, 0x00,  // No flags.
    0x00, 0x01,  // One question.
    0x00, 0x00,  // 0 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Question
    // This part is echoed back from the respective query.
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x01,  // TYPE is A.
    0x00, 0x01,                                 // CLASS is IN.
};
266 
// Response whose only record (_privet._tcp.local PTR hello._privet._tcp.local)
// is carried in the additional-records section. The answer section is empty.
const uint8_t kSamplePacketAdditionalOnly[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x00,  // 0 RRs (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x01,  // 1 additional RR

    // Additional record 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x08,  // RDLENGTH is 8 bytes.
    0x05, 'h', 'e', 'l', 'l', 'o', 0xc0, 0x0c,
};
284 
// Response with a single NSEC record for "_privet._tcp.local", which states
// which record types exist for that name.
const uint8_t kSamplePacketNsec[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x01,  // 1 RR (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x2f,  // TYPE is NSEC.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x01,  // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
    0x24, 0x74, 0x00, 0x06,             // RDLENGTH is 6 bytes.
    // RDATA layout:
    // - Next domain name: compression pointer to offset 12, this record's own
    //   name.
    // - Type bitmap: window block 0, bitmap length 2, bitmap 0x00 0x08.
    // Bit 0 of the bitmap is the MSB of its first octet, so 0x08 in the
    // second octet sets bit 12 only. Type 12 is PTR, so PTR exists and every
    // other type, including A (type 1), does not.
    0xc0, 0x0c, 0x00, 0x02, 0x00, 0x08  // Only PTR record present
};
302 
// Response with a single A record for "_privet._tcp.local" with a TTL of
// 5 seconds. The 4 RDATA bytes 0xc0 0x0c 0x00 0x02 are the IPv4 address
// 192.12.0.2, not a name-compression pointer.
const uint8_t kSamplePacketAPrivet[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x01,  // 1 RR (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x01,  // TYPE is A.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x00,                                 // TTL (4 bytes) is 5 seconds
    0x00, 0x05, 0x00, 0x04,                     // RDLENGTH is 4 bytes.
    0xc0, 0x0c, 0x00, 0x02,
};
320 
// Response with a single _privet._tcp.local PTR answer pointing at
// zzzzz._privet._tcp.local whose TTL is zero. In mDNS (RFC 6762 section 10.1),
// a zero TTL is a "goodbye" announcing that the record is going away.
const uint8_t kSamplePacketGoodbye[] = {
    // Header
    0x00, 0x00,  // ID is zeroed out
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x00,  // No questions (for simplicity)
    0x00, 0x01,  // 1 RR (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs

    // Answer 1
    0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 0x04, '_', 't', 'c', 'p', 0x05,
    'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x0c,  // TYPE is PTR.
    0x00, 0x01,                                 // CLASS is IN.
    0x00, 0x00,                                 // TTL (4 bytes) is zero;
    0x00, 0x00, 0x00, 0x08,                     // RDLENGTH is 8 bytes.
    0x05, 'z', 'z', 'z', 'z', 'z', 0xc0, 0x0c,
};
338 
// Copies |size| raw bytes starting at |data| into a std::string, so that
// packet byte arrays can be compared with the payloads passed to OnSendTo().
std::string MakeString(const uint8_t* data, unsigned size) {
  const uint8_t* const end = data + size;
  return std::string(data, end);
}
342 
343 class PtrRecordCopyContainer {
344  public:
345   PtrRecordCopyContainer() = default;
346   ~PtrRecordCopyContainer() = default;
347 
is_set() const348   bool is_set() const { return set_; }
349 
SaveWithDummyArg(int unused,const RecordParsed * value)350   void SaveWithDummyArg(int unused, const RecordParsed* value) {
351     Save(value);
352   }
353 
Save(const RecordParsed * value)354   void Save(const RecordParsed* value) {
355     set_ = true;
356     name_ = value->name();
357     ptrdomain_ = value->rdata<PtrRecordRdata>()->ptrdomain();
358     ttl_ = value->ttl();
359   }
360 
IsRecordWith(const std::string & name,const std::string & ptrdomain)361   bool IsRecordWith(const std::string& name, const std::string& ptrdomain) {
362     return set_ && name_ == name && ptrdomain_ == ptrdomain;
363   }
364 
name()365   const std::string& name() { return name_; }
ptrdomain()366   const std::string& ptrdomain() { return ptrdomain_; }
ttl()367   int ttl() { return ttl_; }
368 
369  private:
370   bool set_;
371   std::string name_;
372   std::string ptrdomain_;
373   int ttl_;
374 };
375 
376 class MockClock : public base::Clock {
377  public:
378   MockClock() = default;
379   ~MockClock() override = default;
380 
381   MOCK_CONST_METHOD0(Now, base::Time());
382 
383  private:
384   DISALLOW_COPY_AND_ASSIGN(MockClock);
385 };
386 
387 class MockTimer : public base::MockOneShotTimer {
388  public:
MockTimer()389   MockTimer() {}
390   ~MockTimer() override = default;
391 
Start(const base::Location & posted_from,base::TimeDelta delay,base::OnceClosure user_task)392   void Start(const base::Location& posted_from,
393              base::TimeDelta delay,
394              base::OnceClosure user_task) override {
395     StartObserver(posted_from, delay);
396     base::MockOneShotTimer::Start(posted_from, delay, std::move(user_task));
397   }
398 
399   // StartObserver is invoked when MockTimer::Start() is called.
400   // Does not replace the behavior of MockTimer::Start().
401   MOCK_METHOD2(StartObserver,
402                void(const base::Location& posted_from, base::TimeDelta delay));
403 
404  private:
405   DISALLOW_COPY_AND_ASSIGN(MockTimer);
406 };
407 
408 }  // namespace
409 
// Test fixture for MDnsClientImpl. SetUp() creates |test_client_| and starts
// it listening on sockets obtained from the strict-mock |socket_factory_|.
class MDnsTest : public TestWithTaskEnvironment {
 public:
  void SetUp() override;
  // Destroys |transaction_|. Meant to be run from a mock action so tests can
  // exercise deletion from inside a callback.
  void DeleteTransaction();
  // Destroys |listener1_| and then |listener2_|, for the same purpose.
  void DeleteBothListeners();
  // Runs a base::RunLoop until Stop() is called, either by a test expectation
  // or by a task posted to fire after |time_period|.
  void RunFor(base::TimeDelta time_period);
  // Quits the currently running base::RunLoop once it is idle.
  void Stop();

  // Mockable sink for MDnsTransaction result callbacks (tests bind it with
  // base::BindRepeating and base::Unretained(this)).
  MOCK_METHOD2(MockableRecordCallback, void(MDnsTransaction::Result result,
                                            const RecordParsed* record));

  // Second sink with the same signature, presumably so a test can tell the
  // callbacks of two transactions apart.
  MOCK_METHOD2(MockableRecordCallback2, void(MDnsTransaction::Result result,
                                             const RecordParsed* record));

 protected:
  // Expects the |size| bytes at |packet| to be sent through |socket_factory_|
  // exactly twice.
  void ExpectPacket(const uint8_t* packet, unsigned size);
  // Hands the |size| bytes at |packet| to |socket_factory_|'s
  // SimulateReceive().
  void SimulatePacketReceive(const uint8_t* packet, unsigned size);

  // NOTE: members are destroyed in reverse declaration order, so the
  // transaction and listeners below go away before |socket_factory_| and
  // |test_client_|. Re-check dependencies before reordering.
  std::unique_ptr<MDnsClientImpl> test_client_;
  IPEndPoint mdns_ipv4_endpoint_;
  // StrictMock: a call to a mocked method with no expectation set fails the
  // test, so any unexpected outgoing packet is caught.
  StrictMock<MockMDnsSocketFactory> socket_factory_;

  // Transactions and listeners that can be deleted by class methods for
  // reentrancy tests.
  std::unique_ptr<MDnsTransaction> transaction_;
  std::unique_ptr<MDnsListener> listener1_;
  std::unique_ptr<MDnsListener> listener2_;
};
438 
439 class MockListenerDelegate : public MDnsListener::Delegate {
440  public:
441   MOCK_METHOD2(OnRecordUpdate,
442                void(MDnsListener::UpdateType update,
443                     const RecordParsed* records));
444   MOCK_METHOD2(OnNsecRecord, void(const std::string&, unsigned));
445   MOCK_METHOD0(OnCachePurged, void());
446 };
447 
SetUp()448 void MDnsTest::SetUp() {
449   test_client_.reset(new MDnsClientImpl());
450   ASSERT_THAT(test_client_->StartListening(&socket_factory_), test::IsOk());
451 }
452 
SimulatePacketReceive(const uint8_t * packet,unsigned size)453 void MDnsTest::SimulatePacketReceive(const uint8_t* packet, unsigned size) {
454   socket_factory_.SimulateReceive(packet, size);
455 }
456 
ExpectPacket(const uint8_t * packet,unsigned size)457 void MDnsTest::ExpectPacket(const uint8_t* packet, unsigned size) {
458   EXPECT_CALL(socket_factory_, OnSendTo(MakeString(packet, size)))
459       .Times(2);
460 }
461 
DeleteTransaction()462 void MDnsTest::DeleteTransaction() {
463   transaction_.reset();
464 }
465 
DeleteBothListeners()466 void MDnsTest::DeleteBothListeners() {
467   listener1_.reset();
468   listener2_.reset();
469 }
470 
RunFor(base::TimeDelta time_period)471 void MDnsTest::RunFor(base::TimeDelta time_period) {
472   base::CancelableCallback<void()> callback(base::Bind(&MDnsTest::Stop,
473                                                        base::Unretained(this)));
474   base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
475       FROM_HERE, callback.callback(), time_period);
476 
477   base::RunLoop().Run();
478   callback.Cancel();
479 }
480 
// Quits the currently running base::RunLoop once it becomes idle. RunFor()
// posts this as its timeout task, and tests use it as a mock action
// (InvokeWithoutArgs(this, &MDnsTest::Stop)) to end RunFor() early.
void MDnsTest::Stop() {
  base::RunLoop::QuitCurrentWhenIdleDeprecated();
}
484 
TEST_F(MDnsTest, PassiveListeners) {
  // One PTR listener per service type. A single packet carrying a record for
  // each must notify each listener exactly once, even if it arrives twice.
  StrictMock<MockListenerDelegate> privet_delegate;
  StrictMock<MockListenerDelegate> printer_delegate;

  PtrRecordCopyContainer privet_record;
  PtrRecordCopyContainer printer_record;

  std::unique_ptr<MDnsListener> privet_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &privet_delegate);
  std::unique_ptr<MDnsListener> printer_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_printer._tcp.local", &printer_delegate);

  ASSERT_TRUE(privet_listener->Start());
  ASSERT_TRUE(printer_listener->Start());

  EXPECT_CALL(privet_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1))
      .WillOnce(Invoke(&privet_record,
                       &PtrRecordCopyContainer::SaveWithDummyArg));
  EXPECT_CALL(printer_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1))
      .WillOnce(Invoke(&printer_record,
                       &PtrRecordCopyContainer::SaveWithDummyArg));

  // Deliver the identical packet twice; the duplicate must not be reported
  // as a second RECORD_ADDED.
  for (int delivery = 0; delivery < 2; ++delivery)
    SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  EXPECT_TRUE(privet_record.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));
  EXPECT_TRUE(printer_record.IsRecordWith("_printer._tcp.local",
                                          "hello._printer._tcp.local"));

  // Tear the listeners down explicitly, privet first.
  privet_listener.reset();
  printer_listener.reset();
}
527 
TEST_F(MDnsTest, PassiveListenersCacheCleanup) {
  StrictMock<MockListenerDelegate> delegate_privet;

  PtrRecordCopyContainer record_privet;
  PtrRecordCopyContainer record_privet2;

  std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);

  ASSERT_TRUE(listener_privet->Start());

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1))
      .WillOnce(Invoke(&record_privet,
                       &PtrRecordCopyContainer::SaveWithDummyArg));

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));

  // Expect the record to be removed when its TTL expires. Stop() ends RunFor()
  // as soon as the removal is reported.
  // DoAll is namespace-qualified because this file has no using-declaration
  // for it; finding it through argument-dependent lookup is fragile across
  // gMock versions.
  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
      .Times(Exactly(1))
      .WillOnce(::testing::DoAll(
          InvokeWithoutArgs(this, &MDnsTest::Stop),
          Invoke(&record_privet2, &PtrRecordCopyContainer::SaveWithDummyArg)));

  // Wait one second past the record's TTL.
  RunFor(base::TimeDelta::FromSeconds(record_privet.ttl() + 1));

  EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local",
                                          "hello._privet._tcp.local"));
}
562 
563 // Ensure that the cleanup task scheduler won't schedule cleanup tasks in the
564 // past if the system clock creeps past the expiration time while in the
565 // cleanup dispatcher.
TEST_F(MDnsTest,CacheCleanupWithShortTTL)566 TEST_F(MDnsTest, CacheCleanupWithShortTTL) {
567   // Use a nonzero starting time as a base.
568   base::Time start_time = base::Time() + base::TimeDelta::FromSeconds(1);
569 
570   MockClock clock;
571   MockTimer* timer = new MockTimer;
572 
573   test_client_.reset(new MDnsClientImpl(&clock, base::WrapUnique(timer)));
574   ASSERT_THAT(test_client_->StartListening(&socket_factory_), test::IsOk());
575 
576   EXPECT_CALL(*timer, StartObserver(_, _)).Times(1);
577   EXPECT_CALL(clock, Now())
578       .Times(3)
579       .WillRepeatedly(Return(start_time))
580       .RetiresOnSaturation();
581 
582   // Receive two records with different TTL values.
583   // TTL(privet)=1.0s
584   // TTL(printer)=3.0s
585   StrictMock<MockListenerDelegate> delegate_privet;
586   StrictMock<MockListenerDelegate> delegate_printer;
587 
588   PtrRecordCopyContainer record_privet;
589   PtrRecordCopyContainer record_printer;
590 
591   std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
592       dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
593   std::unique_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
594       dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer);
595 
596   ASSERT_TRUE(listener_privet->Start());
597   ASSERT_TRUE(listener_printer->Start());
598 
599   EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
600       .Times(Exactly(1));
601   EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
602       .Times(Exactly(1));
603 
604   SimulatePacketReceive(kSamplePacket3, sizeof(kSamplePacket3));
605 
606   EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
607       .Times(Exactly(1));
608 
609   // Set the clock to 2.0s, which should clean up the 'privet' record, but not
610   // the printer. The mock clock will change Now() mid-execution from 2s to 4s.
611   // Note: expectations are FILO-ordered -- t+2 seconds is returned, then t+4.
612   EXPECT_CALL(clock, Now())
613       .WillOnce(Return(start_time + base::TimeDelta::FromSeconds(4)))
614       .RetiresOnSaturation();
615   EXPECT_CALL(clock, Now())
616       .WillOnce(Return(start_time + base::TimeDelta::FromSeconds(2)))
617       .RetiresOnSaturation();
618 
619   EXPECT_CALL(*timer, StartObserver(_, base::TimeDelta()));
620 
621   timer->Fire();
622 }
623 
TEST_F(MDnsTest, StopListening) {
  // SetUp() has already started the client.
  ASSERT_TRUE(test_client_->IsListening());

  test_client_->StopListening();
  const bool still_listening = test_client_->IsListening();
  EXPECT_FALSE(still_listening);
}
630 
TEST_F(MDnsTest, StopListening_CacheCleanupScheduled) {
  // Start from a nonzero time.
  base::SimpleTestClock clock;
  clock.SetNow(base::Time() + base::TimeDelta::FromSeconds(1));

  // The client takes ownership of the timer; |timer| only observes it.
  auto owned_timer = std::make_unique<base::MockOneShotTimer>();
  base::OneShotTimer* const timer = owned_timer.get();

  test_client_ =
      std::make_unique<MDnsClientImpl>(&clock, std::move(owned_timer));
  ASSERT_THAT(test_client_->StartListening(&socket_factory_), test::IsOk());
  ASSERT_TRUE(test_client_->IsListening());

  // kSamplePacket3 carries records with finite TTLs (privet: 1 s); receiving
  // it must arm the cleanup timer.
  SimulatePacketReceive(kSamplePacket3, sizeof(kSamplePacket3));
  ASSERT_TRUE(timer->IsRunning());

  test_client_->StopListening();
  EXPECT_FALSE(test_client_->IsListening());

  // Stopping must also cancel the pending cleanup.
  EXPECT_FALSE(timer->IsRunning());
}
653 
TEST_F(MDnsTest, MalformedPacket) {
  StrictMock<MockListenerDelegate> printer_delegate;
  PtrRecordCopyContainer printer_record;

  std::unique_ptr<MDnsListener> printer_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_printer._tcp.local", &printer_delegate);
  ASSERT_TRUE(printer_listener->Start());

  // Across the three packets below, the printer listener must be told about
  // exactly one new record.
  EXPECT_CALL(printer_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1))
      .WillOnce(Invoke(&printer_record,
                       &PtrRecordCopyContainer::SaveWithDummyArg));

  // 1. Nothing usable in the packet: it must be handled without crashing.
  SimulatePacketReceive(kCorruptedPacketUnsalvagable,
                        sizeof(kCorruptedPacketUnsalvagable));

  // 2. Regression test: the question section cannot be read.
  SimulatePacketReceive(kCorruptedPacketBadQuestion,
                        sizeof(kCorruptedPacketBadQuestion));

  // 3. A bad answer followed by a good one; the good one must be extracted.
  SimulatePacketReceive(kCorruptedPacketSalvagable,
                        sizeof(kCorruptedPacketSalvagable));

  EXPECT_TRUE(printer_record.IsRecordWith("_printer._tcp.local",
                                          "hello._printer._tcp.local"));
}
685 
TEST_F(MDnsTest, TransactionWithEmptyCache) {
  // Nothing is cached, so starting the transaction must send the PTR query.
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  const int query_flags = MDnsTransaction::QUERY_NETWORK |
                          MDnsTransaction::QUERY_CACHE |
                          MDnsTransaction::SINGLE_RESULT;
  std::unique_ptr<MDnsTransaction> privet_transaction =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local", query_flags,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));
  ASSERT_TRUE(privet_transaction->Start());

  // The response then completes the transaction with one record.
  PtrRecordCopyContainer privet_record;
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .Times(Exactly(1))
      .WillOnce(Invoke(&privet_record,
                       &PtrRecordCopyContainer::SaveWithDummyArg));

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  EXPECT_TRUE(privet_record.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));
}
711 
TEST_F(MDnsTest, TransactionCacheOnlyNoResult) {
  // Cache-only lookup against an empty cache. No query is expected on the
  // strict-mock socket factory, and the callback must report no results.
  std::unique_ptr<MDnsTransaction> privet_transaction =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_CACHE | MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  // Set the expectation before Start(), since the result may be reported
  // while Start() is running.
  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, _))
      .Times(Exactly(1));

  ASSERT_TRUE(privet_transaction->Start());
}
726 
TEST_F(MDnsTest, TransactionWithCache) {
  // An unrelated listener forces the client to listen, presumably so the
  // packet below ends up in the cache.
  StrictMock<MockListenerDelegate> irrelevant_delegate;
  std::unique_ptr<MDnsListener> irrelevant_listener =
      test_client_->CreateListener(dns_protocol::kTypeA,
                                   "codereview.chromium.local",
                                   &irrelevant_delegate);
  ASSERT_TRUE(irrelevant_listener->Start());

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  PtrRecordCopyContainer privet_record;
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .WillOnce(Invoke(&privet_record,
                       &PtrRecordCopyContainer::SaveWithDummyArg));

  // No ExpectPacket() here: any query sent through the strict-mock factory
  // would fail the test.
  std::unique_ptr<MDnsTransaction> privet_transaction =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));
  ASSERT_TRUE(privet_transaction->Start());

  EXPECT_TRUE(privet_record.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));
}
759 
TEST_F(MDnsTest, AdditionalRecords) {
  // A record that appears only in the additional section must still reach a
  // matching listener.
  StrictMock<MockListenerDelegate> privet_delegate;
  PtrRecordCopyContainer privet_record;

  std::unique_ptr<MDnsListener> privet_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &privet_delegate);
  ASSERT_TRUE(privet_listener->Start());

  EXPECT_CALL(privet_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1))
      .WillOnce(Invoke(&privet_record,
                       &PtrRecordCopyContainer::SaveWithDummyArg));

  SimulatePacketReceive(kSamplePacketAdditionalOnly,
                        sizeof(kSamplePacketAdditionalOnly));

  EXPECT_TRUE(privet_record.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));
}
782 
// A single-result network+cache PTR transaction that never receives an answer
// is expected to report RESULT_NO_RESULTS (with no record) exactly once while
// the test runs for 4 seconds.
TEST_F(MDnsTest, TransactionTimeout) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  std::unique_ptr<MDnsTransaction> transaction_privet =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  ASSERT_TRUE(transaction_privet->Start());

  // The record argument must be null for a no-results notification.
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS,
                                            nullptr))
      .Times(Exactly(1))
      .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop));

  // Presumably longer than the transaction timeout -- TODO confirm.
  RunFor(base::TimeDelta::FromSeconds(4));
}
803 
// A network+cache PTR transaction without SINGLE_RESULT is expected to report
// one RESULT_RECORD per matching packet, in arrival order, and later a single
// RESULT_DONE with no record.
TEST_F(MDnsTest, TransactionMultipleRecords) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  std::unique_ptr<MDnsTransaction> transaction_privet =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  ASSERT_TRUE(transaction_privet->Start());

  PtrRecordCopyContainer record_privet;
  PtrRecordCopyContainer record_privet2;

  // First result goes to |record_privet|, second to |record_privet2|.
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .Times(Exactly(2))
      .WillOnce(Invoke(&record_privet,
                       &PtrRecordCopyContainer::SaveWithDummyArg))
      .WillOnce(Invoke(&record_privet2,
                       &PtrRecordCopyContainer::SaveWithDummyArg));

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
  SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2));

  EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
                                         "hello._privet._tcp.local"));

  EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local",
                                          "zzzzz._privet._tcp.local"));

  // Completion is reported with a null record; Stop() is invoked from it.
  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_DONE, nullptr))
      .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop));

  RunFor(base::TimeDelta::FromSeconds(4));
}
840 
// Deleting the transaction from inside its own RESULT_NO_RESULTS callback must
// not crash; afterwards |transaction_| is expected to be null.
TEST_F(MDnsTest, TransactionReentrantDelete) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  transaction_ = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_privet._tcp.local",
      MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
          MDnsTransaction::SINGLE_RESULT,
      base::BindRepeating(&MDnsTest::MockableRecordCallback,
                          base::Unretained(this)));

  ASSERT_TRUE(transaction_->Start());

  // DoAll is not among the file's using-declarations, so qualify it rather
  // than relying on it being found indirectly.
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS,
                                            nullptr))
      .Times(Exactly(1))
      .WillOnce(::testing::DoAll(
          InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction),
          InvokeWithoutArgs(this, &MDnsTest::Stop)));

  RunFor(base::TimeDelta::FromSeconds(4));

  EXPECT_EQ(nullptr, transaction_.get());
}
863 
// Deleting a transaction from inside a RESULT_RECORD callback that fires
// synchronously during Start() (no network packet is simulated after Start, so
// the record can only come from the cache) must not crash.
TEST_F(MDnsTest, TransactionReentrantDeleteFromCache) {
  // An otherwise-unrelated listener is started, presumably so the client is
  // listening and caches the records in kSamplePacket1.
  StrictMock<MockListenerDelegate> delegate_irrelevant;
  std::unique_ptr<MDnsListener> listener_irrelevant =
      test_client_->CreateListener(dns_protocol::kTypeA,
                                   "codereview.chromium.local",
                                   &delegate_irrelevant);
  ASSERT_TRUE(listener_irrelevant->Start());

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));

  transaction_ = test_client_->CreateTransaction(
      dns_protocol::kTypePTR, "_privet._tcp.local",
      MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE,
      base::BindRepeating(&MDnsTest::MockableRecordCallback,
                          base::Unretained(this)));

  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
      .Times(Exactly(1))
      .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction));

  ASSERT_TRUE(transaction_->Start());

  EXPECT_EQ(nullptr, transaction_.get());
}
888 
// Starts a cache-only transaction (transaction2) from inside transaction1's
// RESULT_RECORD callback, i.e. re-entrantly while the client is delivering the
// results of kSamplePacket1. Each transaction is expected to report exactly one
// RESULT_RECORD.
TEST_F(MDnsTest, TransactionReentrantCacheLookupStart) {
  ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));

  // Network + cache lookup; it is the only transaction with QUERY_NETWORK, so
  // it is the one expected to send the query packet above.
  std::unique_ptr<MDnsTransaction> transaction1 =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  // Cache-only lookup. It is not started directly; transaction1's callback
  // starts it below.
  std::unique_ptr<MDnsTransaction> transaction2 =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_printer._tcp.local",
          MDnsTransaction::QUERY_CACHE | MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback2,
                              base::Unretained(this)));

  // Presumably satisfied by a _printer record cached from kSamplePacket1.
  EXPECT_CALL(*this, MockableRecordCallback2(MDnsTransaction::RESULT_RECORD,
                                             _))
      .Times(Exactly(1));

  // IgnoreResult discards the bool returned by MDnsTransaction::Start().
  EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD,
                                            _))
      .Times(Exactly(1))
      .WillOnce(IgnoreResult(InvokeWithoutArgs(transaction2.get(),
                                               &MDnsTransaction::Start)));

  ASSERT_TRUE(transaction1->Start());

  SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
}
921 
// A goodbye packet for a record this client never received must not notify
// the listener. The StrictMock delegate has no expectations, so any callback
// while the packet is handled or during the next two seconds fails the test.
TEST_F(MDnsTest, GoodbyePacketNotification) {
  StrictMock<MockListenerDelegate> ptr_delegate;

  std::unique_ptr<MDnsListener> ptr_listener = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &ptr_delegate);
  ASSERT_TRUE(ptr_listener->Start());

  SimulatePacketReceive(kSamplePacketGoodbye, sizeof(kSamplePacketGoodbye));
  RunFor(base::TimeDelta::FromSeconds(2));
}
933 
// A goodbye packet for a record previously added from kSamplePacket2 is
// expected to lead to exactly one RECORD_REMOVED for it, but not synchronously
// on receipt of the goodbye packet.
TEST_F(MDnsTest, GoodbyePacketRemoval) {
  StrictMock<MockListenerDelegate> delegate_privet;

  std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
  ASSERT_TRUE(listener_privet->Start());

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1));

  SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2));

  SimulatePacketReceive(kSamplePacketGoodbye, sizeof(kSamplePacketGoodbye));

  // Set only after the goodbye packet was handled: with a StrictMock delegate,
  // an immediate removal would fail the test, so this checks that removal
  // happens only once time is advanced below.
  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
      .Times(Exactly(1));

  RunFor(base::TimeDelta::FromSeconds(2));
}
953 
954 // In order to reliably test reentrant listener deletes, we create two listeners
955 // and have each of them delete both, so we're guaranteed to try and deliver a
956 // callback to at least one deleted listener.
957 
TEST_F(MDnsTest,ListenerReentrantDelete)958 TEST_F(MDnsTest, ListenerReentrantDelete) {
959   StrictMock<MockListenerDelegate> delegate_privet;
960 
961   listener1_ = test_client_->CreateListener(
962       dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
963 
964   listener2_ = test_client_->CreateListener(
965       dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
966 
967   ASSERT_TRUE(listener1_->Start());
968 
969   ASSERT_TRUE(listener2_->Start());
970 
971   EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
972       .Times(Exactly(1))
973       .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteBothListeners));
974 
975   SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
976 
977   EXPECT_EQ(NULL, listener1_.get());
978   EXPECT_EQ(NULL, listener2_.get());
979 }
980 
// gmock action: copies the address from the ARecordRdata of the record passed
// as the mocked call's second argument into |*ip_container|. The static
// asserts pin both types at compile time.
ACTION_P(SaveIPAddress, ip_container) {
  ::testing::StaticAssertTypeEq<IPAddress*, ip_container_type>();
  ::testing::StaticAssertTypeEq<const RecordParsed*, arg1_type>();

  const auto* a_rdata = arg1->template rdata<ARecordRdata>();
  *ip_container = a_rdata->address();
}
987 
// kCorruptedPacketDoubleRecord (per its name, a packet repeating a record with
// disagreeing data) is expected to yield exactly one RECORD_ADDED for the
// privet.local A listener, carrying the address 2.3.4.5.
TEST_F(MDnsTest, DoubleRecordDisagreeing) {
  IPAddress saved_address;
  StrictMock<MockListenerDelegate> a_delegate;

  std::unique_ptr<MDnsListener> a_listener = test_client_->CreateListener(
      dns_protocol::kTypeA, "privet.local", &a_delegate);
  ASSERT_TRUE(a_listener->Start());

  EXPECT_CALL(a_delegate, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .Times(Exactly(1))
      .WillOnce(SaveIPAddress(&saved_address));

  SimulatePacketReceive(kCorruptedPacketDoubleRecord,
                        sizeof(kCorruptedPacketDoubleRecord));

  EXPECT_EQ("2.3.4.5", saved_address.ToString());
}
1006 
// An NSEC record for _privet._tcp.local is expected to notify an A-type
// listener on that name via OnNsecRecord(name, kTypeA).
TEST_F(MDnsTest, NsecWithListener) {
  StrictMock<MockListenerDelegate> delegate_privet;
  std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
      dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet);

  // Test to make sure nsec callback is NOT called for PTR
  // (which is marked as existing). delegate_privet2 is a StrictMock with no
  // expectations, so any callback to it fails the test. Its listener must also
  // be started; otherwise nothing exercises it and the check is vacuous
  // (assuming, as the other tests here imply, that only started listeners are
  // notified).
  StrictMock<MockListenerDelegate> delegate_privet2;
  std::unique_ptr<MDnsListener> listener_privet2 = test_client_->CreateListener(
      dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet2);

  ASSERT_TRUE(listener_privet->Start());
  ASSERT_TRUE(listener_privet2->Start());

  EXPECT_CALL(delegate_privet,
              OnNsecRecord("_privet._tcp.local", dns_protocol::kTypeA));

  SimulatePacketReceive(kSamplePacketNsec,
                        sizeof(kSamplePacketNsec));
}
1026 
// A single-result network+cache A transaction whose query is answered by an
// NSEC record is expected to finish with RESULT_NSEC and a null record.
TEST_F(MDnsTest, NsecWithTransactionFromNetwork) {
  std::unique_ptr<MDnsTransaction> transaction_privet =
      test_client_->CreateTransaction(
          dns_protocol::kTypeA, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  // The query is sent twice -- presumably once per socket (IPv4 and IPv6).
  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);

  ASSERT_TRUE(transaction_privet->Start());

  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_NSEC, nullptr));

  SimulatePacketReceive(kSamplePacketNsec,
                        sizeof(kSamplePacketNsec));
}
1046 
// A cached NSEC record should answer an A transaction with RESULT_NSEC during
// Start() (the expectation below is set before Start() and no packet or run
// loop follows). It must NOT answer a PTR transaction for the same name; that
// transaction instead goes to the network.
TEST_F(MDnsTest, NsecWithTransactionFromCache) {
  // Force mDNS to listen.
  StrictMock<MockListenerDelegate> delegate_irrelevant;
  std::unique_ptr<MDnsListener> listener_irrelevant =
      test_client_->CreateListener(dns_protocol::kTypePTR, "_privet._tcp.local",
                                   &delegate_irrelevant);
  // Check the result like every other Start() in this file; if the client is
  // not listening, the NSEC record is never cached and later failures would
  // be misleading.
  ASSERT_TRUE(listener_irrelevant->Start());

  SimulatePacketReceive(kSamplePacketNsec,
                        sizeof(kSamplePacketNsec));

  EXPECT_CALL(*this,
              MockableRecordCallback(MDnsTransaction::RESULT_NSEC, nullptr));

  std::unique_ptr<MDnsTransaction> transaction_privet_a =
      test_client_->CreateTransaction(
          dns_protocol::kTypeA, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  ASSERT_TRUE(transaction_privet_a->Start());

  // Test that a PTR transaction does NOT consider the same NSEC record to be a
  // valid answer to the query.

  std::unique_ptr<MDnsTransaction> transaction_privet_ptr =
      test_client_->CreateTransaction(
          dns_protocol::kTypePTR, "_privet._tcp.local",
          MDnsTransaction::QUERY_NETWORK | MDnsTransaction::QUERY_CACHE |
              MDnsTransaction::SINGLE_RESULT,
          base::BindRepeating(&MDnsTest::MockableRecordCallback,
                              base::Unretained(this)));

  EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);

  ASSERT_TRUE(transaction_privet_ptr->Start());
}
1086 
// An NSEC record that denies the A type for _privet._tcp.local is expected to
// evict the previously added A record for that name. The listener should see
// RECORD_REMOVED for the same record it saw in RECORD_ADDED, plus an
// OnNsecRecord notification.
TEST_F(MDnsTest, NsecConflictRemoval) {
  StrictMock<MockListenerDelegate> delegate_privet;
  std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
      dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet);

  ASSERT_TRUE(listener_privet->Start());

  // Initialized so the comparison below never reads indeterminate pointers if
  // one of the callbacks is not delivered.
  const RecordParsed* record1 = nullptr;
  const RecordParsed* record2 = nullptr;

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
      .WillOnce(SaveArg<1>(&record1));

  SimulatePacketReceive(kSamplePacketAPrivet,
                        sizeof(kSamplePacketAPrivet));
  // Without an added record, the equality check below would be meaningless.
  ASSERT_NE(nullptr, record1);

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
      .WillOnce(SaveArg<1>(&record2));

  EXPECT_CALL(delegate_privet,
              OnNsecRecord("_privet._tcp.local", dns_protocol::kTypeA));

  SimulatePacketReceive(kSamplePacketNsec,
                        sizeof(kSamplePacketNsec));

  // Only pointer identity is compared; the removed record is not dereferenced.
  EXPECT_EQ(record1, record2);
}
1114 
1115 
// With active refresh enabled, a listener is expected to re-query for its A
// record (4 kQueryPacketPrivetA sends in total) and then report the record as
// removed within the 6-second run. Presumably the record's TTL elapses
// because no refreshed answer is simulated -- TODO confirm against
// kSamplePacketAPrivet's TTL.
TEST_F(MDnsTest, RefreshQuery) {
  StrictMock<MockListenerDelegate> delegate_privet;
  std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
      dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet);

  listener_privet->SetActiveRefresh(true);
  ASSERT_TRUE(listener_privet->Start());

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _));

  SimulatePacketReceive(kSamplePacketAPrivet,
                        sizeof(kSamplePacketAPrivet));

  // Expecting 2 calls (one for ipv4 and one for ipv6) for each of the 2
  // scheduled refresh queries.
  EXPECT_CALL(socket_factory_, OnSendTo(
      MakeString(kQueryPacketPrivetA, sizeof(kQueryPacketPrivetA))))
      .Times(4);

  EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _));

  RunFor(base::TimeDelta::FromSeconds(6));
}
1139 
1140 // MDnsSocketFactory implementation that creates a single socket that will
1141 // always fail on RecvFrom. Passing this to MdnsClient is expected to result in
1142 // the client failing to start listening.
1143 class FailingSocketFactory : public MDnsSocketFactory {
CreateSockets(std::vector<std::unique_ptr<DatagramServerSocket>> * sockets)1144   void CreateSockets(
1145       std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override {
1146     auto socket =
1147         std::make_unique<MockMDnsDatagramServerSocket>(ADDRESS_FAMILY_IPV4);
1148     EXPECT_CALL(*socket, RecvFromInternal(_, _, _, _))
1149         .WillRepeatedly(Return(ERR_FAILED));
1150     sockets->push_back(std::move(socket));
1151   }
1152 };
1153 
// StartListening() is expected to surface the RecvFrom failure produced by
// FailingSocketFactory's socket as ERR_FAILED.
TEST_F(MDnsTest, StartListeningFailure) {
  // Use a fresh client so StartListening() is driven here with the failing
  // factory rather than the fixture's one.
  test_client_ = std::make_unique<MDnsClientImpl>();

  FailingSocketFactory failing_factory;
  const auto rv = test_client_->StartListening(&failing_factory);
  EXPECT_THAT(rv, test::IsError(ERR_FAILED));
}
1161 
1162 // Test that the cache is cleared when it gets filled to unreasonable sizes.
TEST_F(MDnsTest,ClearOverfilledCache)1163 TEST_F(MDnsTest, ClearOverfilledCache) {
1164   test_client_->core()->cache_for_testing()->set_entry_limit_for_testing(1);
1165 
1166   StrictMock<MockListenerDelegate> delegate_privet;
1167   StrictMock<MockListenerDelegate> delegate_printer;
1168 
1169   PtrRecordCopyContainer record_privet;
1170   PtrRecordCopyContainer record_printer;
1171 
1172   std::unique_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
1173       dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
1174   std::unique_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
1175       dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer);
1176 
1177   ASSERT_TRUE(listener_privet->Start());
1178   ASSERT_TRUE(listener_printer->Start());
1179 
1180   bool privet_added = false;
1181   EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
1182       .Times(AtMost(1))
1183       .WillOnce(Assign(&privet_added, true));
1184   EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
1185       .WillRepeatedly(Assign(&privet_added, false));
1186 
1187   bool printer_added = false;
1188   EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
1189       .Times(AtMost(1))
1190       .WillOnce(Assign(&printer_added, true));
1191   EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
1192       .WillRepeatedly(Assign(&printer_added, false));
1193 
1194   // Fill past capacity and expect everything to eventually be removed.
1195   SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
1196   base::RunLoop().RunUntilIdle();
1197   EXPECT_FALSE(privet_added);
1198   EXPECT_FALSE(printer_added);
1199 }
1200 
1201 // Note: These tests assume that the ipv4 socket will always be created first.
1202 // This is a simplifying assumption based on the way the code works now.
1203 class SimpleMockSocketFactory : public MDnsSocketFactory {
1204  public:
CreateSockets(std::vector<std::unique_ptr<DatagramServerSocket>> * sockets)1205   void CreateSockets(
1206       std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override {
1207     sockets->clear();
1208     sockets->swap(sockets_);
1209   }
1210 
PushSocket(std::unique_ptr<DatagramServerSocket> socket)1211   void PushSocket(std::unique_ptr<DatagramServerSocket> socket) {
1212     sockets_.push_back(std::move(socket));
1213   }
1214 
1215  private:
1216   std::vector<std::unique_ptr<DatagramServerSocket>> sockets_;
1217 };
1218 
1219 class MockMDnsConnectionDelegate : public MDnsConnection::Delegate {
1220  public:
HandlePacket(DnsResponse * response,int size)1221   void HandlePacket(DnsResponse* response, int size) override {
1222     HandlePacketInternal(std::string(response->io_buffer()->data(), size));
1223   }
1224 
1225   MOCK_METHOD1(HandlePacketInternal, void(std::string packet));
1226 
1227   MOCK_METHOD1(OnConnectionError, void(int error));
1228 };
1229 
1230 class MDnsConnectionTest : public TestWithTaskEnvironment {
1231  public:
MDnsConnectionTest()1232   MDnsConnectionTest() : connection_(&delegate_) {
1233   }
1234 
1235  protected:
1236   // Follow successful connection initialization.
SetUp()1237   void SetUp() override {
1238     socket_ipv4_ = new MockMDnsDatagramServerSocket(ADDRESS_FAMILY_IPV4);
1239     socket_ipv6_ = new MockMDnsDatagramServerSocket(ADDRESS_FAMILY_IPV6);
1240     factory_.PushSocket(base::WrapUnique(socket_ipv6_));
1241     factory_.PushSocket(base::WrapUnique(socket_ipv4_));
1242     sample_packet_ = MakeString(kSamplePacket1, sizeof(kSamplePacket1));
1243     sample_buffer_ = base::MakeRefCounted<StringIOBuffer>(sample_packet_);
1244   }
1245 
InitConnection()1246   int InitConnection() { return connection_.Init(&factory_); }
1247 
1248   StrictMock<MockMDnsConnectionDelegate> delegate_;
1249 
1250   MockMDnsDatagramServerSocket* socket_ipv4_;
1251   MockMDnsDatagramServerSocket* socket_ipv6_;
1252   SimpleMockSocketFactory factory_;
1253   MDnsConnection connection_;
1254   TestCompletionCallback callback_;
1255   std::string sample_packet_;
1256   scoped_refptr<IOBuffer> sample_buffer_;
1257 };
1258 
// The IPv6 socket's first RecvFrom completes synchronously with sample_packet_
// (HandleRecvNow). The connection is expected to hand it to the delegate during
// Init and then issue another RecvFrom, which stays pending.
TEST_F(MDnsConnectionTest, ReceiveSynchronous) {
  socket_ipv6_->SetResponsePacket(sample_packet_);
  EXPECT_CALL(*socket_ipv4_, RecvFromInternal(_, _, _, _))
      .WillOnce(Return(ERR_IO_PENDING));
  EXPECT_CALL(*socket_ipv6_, RecvFromInternal(_, _, _, _))
      .WillOnce(
          Invoke(socket_ipv6_, &MockMDnsDatagramServerSocket::HandleRecvNow))
      .WillOnce(Return(ERR_IO_PENDING));

  // Set before InitConnection() because delivery happens inside Init.
  EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet_));
  EXPECT_THAT(InitConnection(), test::IsOk());
}
1271 
// The IPv6 socket's first RecvFrom completes later (HandleRecvLater). The
// packet is expected to reach the delegate only when the message loop runs,
// after which a second RecvFrom stays pending.
TEST_F(MDnsConnectionTest, ReceiveAsynchronous) {
  socket_ipv6_->SetResponsePacket(sample_packet_);

  EXPECT_CALL(*socket_ipv4_, RecvFromInternal(_, _, _, _))
      .WillOnce(Return(ERR_IO_PENDING));
  EXPECT_CALL(*socket_ipv6_, RecvFromInternal(_, _, _, _))
      .Times(2)
      .WillOnce(
          Invoke(socket_ipv6_, &MockMDnsDatagramServerSocket::HandleRecvLater))
      .WillOnce(Return(ERR_IO_PENDING));

  ASSERT_THAT(InitConnection(), test::IsOk());

  // Set after Init: with the StrictMock delegate, a delivery during Init would
  // fail the test.
  EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet_));

  base::RunLoop().RunUntilIdle();
}
1289 
TEST_F(MDnsConnectionTest,Error)1290 TEST_F(MDnsConnectionTest, Error) {
1291   CompletionRepeatingCallback callback;
1292 
1293   EXPECT_CALL(*socket_ipv4_, RecvFromInternal(_, _, _, _))
1294       .WillOnce(Return(ERR_IO_PENDING));
1295   EXPECT_CALL(*socket_ipv6_, RecvFromInternal(_, _, _, _))
1296       .WillOnce(DoAll(SaveArg<3>(&callback), Return(ERR_IO_PENDING)));
1297 
1298   ASSERT_THAT(InitConnection(), test::IsOk());
1299 
1300   EXPECT_CALL(delegate_, OnConnectionError(ERR_SOCKET_NOT_CONNECTED));
1301   callback.Run(ERR_SOCKET_NOT_CONNECTED);
1302   base::RunLoop().RunUntilIdle();
1303 }
1304 
// Fixture for send-path tests: SetUp() initializes the connection with both
// sockets' first RecvFrom left pending, so each test starts from an
// initialized connection with no received packets.
class MDnsConnectionSendTest : public MDnsConnectionTest {
 protected:
  void SetUp() override {
    MDnsConnectionTest::SetUp();
    EXPECT_CALL(*socket_ipv4_, RecvFromInternal(_, _, _, _))
        .WillOnce(Return(ERR_IO_PENDING));
    EXPECT_CALL(*socket_ipv6_, RecvFromInternal(_, _, _, _))
        .WillOnce(Return(ERR_IO_PENDING));
    EXPECT_THAT(InitConnection(), test::IsOk());
  }
};
1316 
// Send() is expected to write the packet to the IPv4 mDNS destination
// 224.0.0.251:5353 and the IPv6 destination [ff02::fb]:5353.
TEST_F(MDnsConnectionSendTest, Send) {
  EXPECT_CALL(*socket_ipv4_,
              SendToInternal(sample_packet_, "224.0.0.251:5353", _));
  EXPECT_CALL(*socket_ipv6_,
              SendToInternal(sample_packet_, "[ff02::fb]:5353", _));

  connection_.Send(sample_buffer_, sample_packet_.size());
}
1325 
// A synchronous send failure on the IPv6 socket is expected to be reported to
// the delegate as OnConnectionError, but only once the message loop runs.
TEST_F(MDnsConnectionSendTest, SendError) {
  EXPECT_CALL(*socket_ipv4_,
              SendToInternal(sample_packet_, "224.0.0.251:5353", _));
  EXPECT_CALL(*socket_ipv6_,
              SendToInternal(sample_packet_, "[ff02::fb]:5353", _))
      .WillOnce(Return(ERR_SOCKET_NOT_CONNECTED));

  connection_.Send(sample_buffer_, sample_packet_.size());
  // Set after Send(): with the StrictMock delegate, a synchronous error report
  // would fail the test.
  EXPECT_CALL(delegate_, OnConnectionError(ERR_SOCKET_NOT_CONNECTED));
  base::RunLoop().RunUntilIdle();
}
1337 
// While an IPv6 write is pending, a second Send() goes out immediately on IPv4
// but is held back on IPv6. Completing the pending IPv6 write is expected to
// send the held-back packet.
TEST_F(MDnsConnectionSendTest, SendQueued) {
  // Send data immediately.
  EXPECT_CALL(*socket_ipv4_,
              SendToInternal(sample_packet_, "224.0.0.251:5353", _))
      .Times(2)
      .WillRepeatedly(Return(OK));

  CompletionRepeatingCallback callback;
  // Delay sending data. Only the first call should be made; its completion
  // callback (3rd argument) is saved so the test can finish the write. DoAll
  // is not among the file's using-declarations, so it is qualified.
  EXPECT_CALL(*socket_ipv6_,
              SendToInternal(sample_packet_, "[ff02::fb]:5353", _))
      .WillOnce(
          ::testing::DoAll(SaveArg<2>(&callback), Return(ERR_IO_PENDING)));

  connection_.Send(sample_buffer_, sample_packet_.size());
  connection_.Send(sample_buffer_, sample_packet_.size());

  // Both IPv4 packets have already been sent; no further IPv4 sends are
  // expected while the queued IPv6 packet is flushed.
  EXPECT_CALL(*socket_ipv4_,
              SendToInternal(sample_packet_, "224.0.0.251:5353", _))
      .Times(0);
  // Expect call for the second IPv6 packet.
  EXPECT_CALL(*socket_ipv6_,
              SendToInternal(sample_packet_, "[ff02::fb]:5353", _))
      .WillOnce(Return(OK));
  // Completing the first IPv6 write releases the queued one.
  callback.Run(OK);
}
1364 
// Smoke test: an IPv4 mDNS socket can be created and bound with the global
// NetLog, then closed.
TEST(MDnsSocketTest, CreateSocket) {
  // Verifies that socket creation hasn't been broken.
  auto socket = CreateAndBindMDnsSocket(AddressFamily::ADDRESS_FAMILY_IPV4, 1,
                                        net::NetLog::Get());
  // ASSERT, not EXPECT: the Close() call below dereferences |socket|, so a
  // failed creation must abort the test instead of crashing it.
  ASSERT_TRUE(socket);
  socket->Close();
}
1372 
1373 }  // namespace net
1374