1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/local_discovery/service_discovery_client_mdns.h"
6 
7 #include <stddef.h>
8 
9 #include <utility>
10 #include <vector>
11 
12 #include "base/bind.h"
13 #include "base/location.h"
14 #include "base/macros.h"
15 #include "base/task/thread_pool.h"
16 #include "base/threading/thread_task_runner_handle.h"
17 #include "chrome/browser/browser_process.h"
18 #include "chrome/browser/local_discovery/service_discovery_client_impl.h"
19 #include "content/public/browser/browser_task_traits.h"
20 #include "content/public/browser/browser_thread.h"
21 #include "content/public/browser/network_service_instance.h"
22 #include "net/base/net_errors.h"
23 #include "net/socket/datagram_server_socket.h"
24 
25 namespace net {
26 class IPAddress;
27 }
28 
29 namespace local_discovery {
30 
31 using content::BrowserThread;
32 
33 // Base class for objects returned by ServiceDiscoveryClient implementation.
34 // Handles interaction of client code on UI thread end net code on mdns thread.
35 class ServiceDiscoveryClientMdns::Proxy {
36  public:
37   using WeakPtr = base::WeakPtr<Proxy>;
38 
Proxy(ServiceDiscoveryClientMdns * client)39   explicit Proxy(ServiceDiscoveryClientMdns* client) : client_(client) {
40     DCHECK_CURRENTLY_ON(BrowserThread::UI);
41     client_->proxies_.AddObserver(this);
42   }
43 
~Proxy()44   virtual ~Proxy() {
45     DCHECK_CURRENTLY_ON(BrowserThread::UI);
46     client_->proxies_.RemoveObserver(this);
47   }
48 
49   // Returns true if object is not yet shutdown.
50   virtual bool IsValid() = 0;
51 
52   // Notifies proxies that mDNS layer is going to be destroyed.
53   virtual void OnMdnsDestroy() = 0;
54 
55   // Notifies proxies that new mDNS instance is ready.
OnNewMdnsReady()56   virtual void OnNewMdnsReady() {
57     DCHECK(!client_->need_delay_mdns_tasks_);
58     if (IsValid()) {
59       for (auto& task : delayed_tasks_)
60         client_->mdns_runner_->PostTask(FROM_HERE, std::move(task));
61     }
62     delayed_tasks_.clear();
63   }
64 
65   // Runs callback using this method to abort callback if instance of |Proxy|
66   // is deleted.
RunCallback(base::OnceClosure callback)67   void RunCallback(base::OnceClosure callback) {
68     DCHECK_CURRENTLY_ON(BrowserThread::UI);
69     std::move(callback).Run();
70   }
71 
72  protected:
PostToMdnsThread(base::OnceClosure task)73   void PostToMdnsThread(base::OnceClosure task) {
74     DCHECK(IsValid());
75     // The first task on the IO thread for each |mdns_| instance must be
76     // InitMdns(). OnInterfaceListReady() could be delayed by
77     // GetMDnsInterfacesToBind() running on a background task runner, so
78     // PostToMdnsThread() could be called to post task for |mdns_| that is not
79     // initialized yet.
80     if (!client_->need_delay_mdns_tasks_) {
81       client_->mdns_runner_->PostTask(FROM_HERE, std::move(task));
82       return;
83     }
84     delayed_tasks_.emplace_back(std::move(task));
85   }
86 
PostToUIThread(base::OnceClosure task)87   static bool PostToUIThread(base::OnceClosure task) {
88     return content::GetUIThreadTaskRunner({})->PostTask(FROM_HERE,
89                                                         std::move(task));
90   }
91 
client()92   ServiceDiscoveryClient* client() {
93     return client_->client_.get();
94   }
95 
GetWeakPtr()96   WeakPtr GetWeakPtr() {
97     return weak_ptr_factory_.GetWeakPtr();
98   }
99 
100   template<class T>
DeleteOnMdnsThread(T * t)101   void DeleteOnMdnsThread(T* t) {
102     if (!t)
103       return;
104     if (!client_->mdns_runner_->DeleteSoon(FROM_HERE, t))
105       delete t;
106   }
107 
108  private:
109   scoped_refptr<ServiceDiscoveryClientMdns> client_;
110   // Delayed |mdns_runner_| tasks.
111   std::vector<base::OnceClosure> delayed_tasks_;
112   base::WeakPtrFactory<Proxy> weak_ptr_factory_{this};
113 
114   DISALLOW_COPY_AND_ASSIGN(Proxy);
115 };
116 
117 namespace {
118 
// How many times the mDNS client may be restarted before giving up; only a
// network change (OnConnectionChanged()) resets the counter.
constexpr int kMaxRestartAttempts = 10;
// Base delay before a restart; ScheduleStartNewClient() doubles it for every
// attempt already made.
constexpr int kRestartDelayOnNetworkChangeSeconds = 3;
121 
122 using MdnsInitCallback = base::OnceCallback<void(int)>;
123 
124 class SocketFactory : public net::MDnsSocketFactory {
125  public:
SocketFactory(const net::InterfaceIndexFamilyList & interfaces)126   explicit SocketFactory(const net::InterfaceIndexFamilyList& interfaces)
127       : interfaces_(interfaces) {}
128 
129   // net::MDnsSocketFactory implementation:
CreateSockets(std::vector<std::unique_ptr<net::DatagramServerSocket>> * sockets)130   void CreateSockets(std::vector<std::unique_ptr<net::DatagramServerSocket>>*
131                          sockets) override {
132     for (size_t i = 0; i < interfaces_.size(); ++i) {
133       DCHECK(interfaces_[i].second == net::ADDRESS_FAMILY_IPV4 ||
134              interfaces_[i].second == net::ADDRESS_FAMILY_IPV6);
135       std::unique_ptr<net::DatagramServerSocket> socket(CreateAndBindMDnsSocket(
136           interfaces_[i].second, interfaces_[i].first, nullptr /* net_log */));
137       if (socket)
138         sockets->push_back(std::move(socket));
139     }
140   }
141 
142  private:
143   net::InterfaceIndexFamilyList interfaces_;
144 
145   DISALLOW_COPY_AND_ASSIGN(SocketFactory);
146 };
147 
InitMdns(MdnsInitCallback on_initialized,const net::InterfaceIndexFamilyList & interfaces,net::MDnsClient * mdns)148 void InitMdns(MdnsInitCallback on_initialized,
149               const net::InterfaceIndexFamilyList& interfaces,
150               net::MDnsClient* mdns) {
151   SocketFactory socket_factory(interfaces);
152   content::GetUIThreadTaskRunner({})->PostTask(
153       FROM_HERE, base::BindOnce(std::move(on_initialized),
154                                 mdns->StartListening(&socket_factory)));
155 }
156 
157 template<class T>
158 class ProxyBase : public ServiceDiscoveryClientMdns::Proxy, public T {
159  public:
160   using Base = ProxyBase<T>;
161 
ProxyBase(ServiceDiscoveryClientMdns * client)162   explicit ProxyBase(ServiceDiscoveryClientMdns* client)
163       : Proxy(client) {
164   }
165 
~ProxyBase()166   ~ProxyBase() override {
167     DeleteOnMdnsThread(implementation_.release());
168   }
169 
IsValid()170   bool IsValid() override {
171     return !!implementation();
172   }
173 
OnMdnsDestroy()174   void OnMdnsDestroy() override {
175     DeleteOnMdnsThread(implementation_.release());
176   }
177 
178  protected:
set_implementation(std::unique_ptr<T> implementation)179   void set_implementation(std::unique_ptr<T> implementation) {
180     implementation_ = std::move(implementation);
181   }
182 
implementation() const183   T* implementation()  const {
184     return implementation_.get();
185   }
186 
187  private:
188   std::unique_ptr<T> implementation_;
189 
190   DISALLOW_COPY_AND_ASSIGN(ProxyBase);
191 };
192 
193 class ServiceWatcherProxy : public ProxyBase<ServiceWatcher> {
194  public:
ServiceWatcherProxy(ServiceDiscoveryClientMdns * client_mdns,const std::string & service_type,ServiceWatcher::UpdatedCallback callback)195   ServiceWatcherProxy(ServiceDiscoveryClientMdns* client_mdns,
196                       const std::string& service_type,
197                       ServiceWatcher::UpdatedCallback callback)
198       : ProxyBase(client_mdns),
199         service_type_(service_type),
200         callback_(callback) {
201     // It's safe to call |CreateServiceWatcher| on UI thread, because
202     // |MDnsClient| is not used there. It's simplify implementation.
203     set_implementation(client()->CreateServiceWatcher(
204         service_type, base::BindRepeating(&ServiceWatcherProxy::OnCallback,
205                                           GetWeakPtr(), std::move(callback))));
206   }
207 
208   // ServiceWatcher methods.
Start()209   void Start() override {
210     if (implementation()) {
211       PostToMdnsThread(base::BindOnce(&ServiceWatcher::Start,
212                                       base::Unretained(implementation())));
213     }
214   }
215 
DiscoverNewServices()216   void DiscoverNewServices() override {
217     if (implementation()) {
218       PostToMdnsThread(base::BindOnce(&ServiceWatcher::DiscoverNewServices,
219                                       base::Unretained(implementation())));
220     }
221   }
222 
SetActivelyRefreshServices(bool actively_refresh_services)223   void SetActivelyRefreshServices(bool actively_refresh_services) override {
224     if (implementation()) {
225       PostToMdnsThread(base::BindOnce(
226           &ServiceWatcher::SetActivelyRefreshServices,
227           base::Unretained(implementation()), actively_refresh_services));
228     }
229   }
230 
GetServiceType() const231   std::string GetServiceType() const override { return service_type_; }
232 
OnNewMdnsReady()233   void OnNewMdnsReady() override {
234     ProxyBase<ServiceWatcher>::OnNewMdnsReady();
235     if (!implementation())
236       callback_.Run(ServiceWatcher::UPDATE_INVALIDATED, "");
237   }
238 
239  private:
OnCallback(const WeakPtr & proxy,ServiceWatcher::UpdatedCallback callback,UpdateType a1,const std::string & a2)240   static void OnCallback(const WeakPtr& proxy,
241                          ServiceWatcher::UpdatedCallback callback,
242                          UpdateType a1,
243                          const std::string& a2) {
244     DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
245     PostToUIThread(base::BindOnce(&Base::RunCallback, proxy,
246                                   base::BindOnce(std::move(callback), a1, a2)));
247   }
248   std::string service_type_;
249   ServiceWatcher::UpdatedCallback callback_;
250 
251   DISALLOW_COPY_AND_ASSIGN(ServiceWatcherProxy);
252 };
253 
254 class ServiceResolverProxy : public ProxyBase<ServiceResolver> {
255  public:
ServiceResolverProxy(ServiceDiscoveryClientMdns * client_mdns,const std::string & service_name,ServiceResolver::ResolveCompleteCallback callback)256   ServiceResolverProxy(ServiceDiscoveryClientMdns* client_mdns,
257                        const std::string& service_name,
258                        ServiceResolver::ResolveCompleteCallback callback)
259       : ProxyBase(client_mdns), service_name_(service_name) {
260     // It's safe to call |CreateServiceResolver| on UI thread, because
261     // |MDnsClient| is not used there. It's simplify implementation.
262     set_implementation(client()->CreateServiceResolver(
263         service_name, base::BindOnce(&ServiceResolverProxy::OnCallback,
264                                      GetWeakPtr(), std::move(callback))));
265   }
266 
267   // ServiceResolver methods.
StartResolving()268   void StartResolving() override {
269     if (implementation()) {
270       PostToMdnsThread(base::BindOnce(&ServiceResolver::StartResolving,
271                                       base::Unretained(implementation())));
272     }
273   }
274 
GetName() const275   std::string GetName() const override { return service_name_; }
276 
277  private:
OnCallback(const WeakPtr & proxy,ServiceResolver::ResolveCompleteCallback callback,RequestStatus a1,const ServiceDescription & a2)278   static void OnCallback(const WeakPtr& proxy,
279                          ServiceResolver::ResolveCompleteCallback callback,
280                          RequestStatus a1,
281                          const ServiceDescription& a2) {
282     DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
283     PostToUIThread(base::BindOnce(&Base::RunCallback, proxy,
284                                   base::BindOnce(std::move(callback), a1, a2)));
285   }
286 
287   std::string service_name_;
288 
289   DISALLOW_COPY_AND_ASSIGN(ServiceResolverProxy);
290 };
291 
292 class LocalDomainResolverProxy : public ProxyBase<LocalDomainResolver> {
293  public:
LocalDomainResolverProxy(ServiceDiscoveryClientMdns * client_mdns,const std::string & domain,net::AddressFamily address_family,LocalDomainResolver::IPAddressCallback callback)294   LocalDomainResolverProxy(ServiceDiscoveryClientMdns* client_mdns,
295                            const std::string& domain,
296                            net::AddressFamily address_family,
297                            LocalDomainResolver::IPAddressCallback callback)
298       : ProxyBase(client_mdns) {
299     // It's safe to call |CreateLocalDomainResolver| on UI thread, because
300     // |MDnsClient| is not used there. It's simplify implementation.
301     set_implementation(client()->CreateLocalDomainResolver(
302         domain, address_family,
303         base::BindOnce(&LocalDomainResolverProxy::OnCallback, GetWeakPtr(),
304                        std::move(callback))));
305   }
306 
307   // LocalDomainResolver methods.
Start()308   void Start() override {
309     if (implementation()) {
310       PostToMdnsThread(base::BindOnce(&LocalDomainResolver::Start,
311                                       base::Unretained(implementation())));
312     }
313   }
314 
315  private:
OnCallback(const WeakPtr & proxy,LocalDomainResolver::IPAddressCallback callback,bool a1,const net::IPAddress & a2,const net::IPAddress & a3)316   static void OnCallback(const WeakPtr& proxy,
317                          LocalDomainResolver::IPAddressCallback callback,
318                          bool a1,
319                          const net::IPAddress& a2,
320                          const net::IPAddress& a3) {
321     DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
322     PostToUIThread(
323         base::BindOnce(&Base::RunCallback, proxy,
324                        base::BindOnce(std::move(callback), a1, a2, a3)));
325   }
326 
327   DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverProxy);
328 };
329 
330 }  // namespace
331 
ServiceDiscoveryClientMdns()332 ServiceDiscoveryClientMdns::ServiceDiscoveryClientMdns()
333     : mdns_runner_(content::GetIOThreadTaskRunner({})) {
334   DCHECK_CURRENTLY_ON(BrowserThread::UI);
335   content::GetNetworkConnectionTracker()->AddNetworkConnectionObserver(this);
336   StartNewClient();
337 }
338 
339 std::unique_ptr<ServiceWatcher>
CreateServiceWatcher(const std::string & service_type,ServiceWatcher::UpdatedCallback callback)340 ServiceDiscoveryClientMdns::CreateServiceWatcher(
341     const std::string& service_type,
342     ServiceWatcher::UpdatedCallback callback) {
343   DCHECK_CURRENTLY_ON(BrowserThread::UI);
344   return std::make_unique<ServiceWatcherProxy>(this, service_type,
345                                                std::move(callback));
346 }
347 
348 std::unique_ptr<ServiceResolver>
CreateServiceResolver(const std::string & service_name,ServiceResolver::ResolveCompleteCallback callback)349 ServiceDiscoveryClientMdns::CreateServiceResolver(
350     const std::string& service_name,
351     ServiceResolver::ResolveCompleteCallback callback) {
352   DCHECK_CURRENTLY_ON(BrowserThread::UI);
353   return std::make_unique<ServiceResolverProxy>(this, service_name,
354                                                 std::move(callback));
355 }
356 
357 std::unique_ptr<LocalDomainResolver>
CreateLocalDomainResolver(const std::string & domain,net::AddressFamily address_family,LocalDomainResolver::IPAddressCallback callback)358 ServiceDiscoveryClientMdns::CreateLocalDomainResolver(
359     const std::string& domain,
360     net::AddressFamily address_family,
361     LocalDomainResolver::IPAddressCallback callback) {
362   DCHECK_CURRENTLY_ON(BrowserThread::UI);
363   return std::make_unique<LocalDomainResolverProxy>(
364       this, domain, address_family, std::move(callback));
365 }
366 
~ServiceDiscoveryClientMdns()367 ServiceDiscoveryClientMdns::~ServiceDiscoveryClientMdns() {
368   DCHECK_CURRENTLY_ON(BrowserThread::UI);
369   content::GetNetworkConnectionTracker()->RemoveNetworkConnectionObserver(this);
370   DestroyMdns();
371 }
372 
OnConnectionChanged(network::mojom::ConnectionType type)373 void ServiceDiscoveryClientMdns::OnConnectionChanged(
374     network::mojom::ConnectionType type) {
375   DCHECK_CURRENTLY_ON(BrowserThread::UI);
376   // Only network changes resets counter.
377   restart_attempts_ = 0;
378   ScheduleStartNewClient();
379 }
380 
ScheduleStartNewClient()381 void ServiceDiscoveryClientMdns::ScheduleStartNewClient() {
382   DCHECK_CURRENTLY_ON(BrowserThread::UI);
383   OnBeforeMdnsDestroy();
384   if (restart_attempts_ >= kMaxRestartAttempts)
385     return;
386 
387   base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
388       FROM_HERE,
389       base::BindOnce(&ServiceDiscoveryClientMdns::StartNewClient,
390                      weak_ptr_factory_.GetWeakPtr()),
391       base::TimeDelta::FromSeconds(kRestartDelayOnNetworkChangeSeconds *
392                                    (1 << restart_attempts_)));
393 }
394 
StartNewClient()395 void ServiceDiscoveryClientMdns::StartNewClient() {
396   DCHECK_CURRENTLY_ON(BrowserThread::UI);
397   ++restart_attempts_;
398   DestroyMdns();
399   mdns_ = net::MDnsClient::CreateDefault();
400   client_ = std::make_unique<ServiceDiscoveryClientImpl>(mdns_.get());
401   base::ThreadPool::PostTaskAndReplyWithResult(
402       FROM_HERE, {base::TaskPriority::BEST_EFFORT, base::MayBlock()},
403       base::BindOnce(&net::GetMDnsInterfacesToBind),
404       base::BindOnce(&ServiceDiscoveryClientMdns::OnInterfaceListReady,
405                      weak_ptr_factory_.GetWeakPtr()));
406 }
407 
OnInterfaceListReady(const net::InterfaceIndexFamilyList & interfaces)408 void ServiceDiscoveryClientMdns::OnInterfaceListReady(
409     const net::InterfaceIndexFamilyList& interfaces) {
410   DCHECK_CURRENTLY_ON(BrowserThread::UI);
411   mdns_runner_->PostTask(
412       FROM_HERE,
413       base::BindOnce(
414           &InitMdns,
415           base::BindOnce(&ServiceDiscoveryClientMdns::OnMdnsInitialized,
416                          weak_ptr_factory_.GetWeakPtr()),
417           interfaces, base::Unretained(mdns_.get())));
418 }
419 
OnMdnsInitialized(int net_error)420 void ServiceDiscoveryClientMdns::OnMdnsInitialized(int net_error) {
421   DCHECK_CURRENTLY_ON(BrowserThread::UI);
422   if (net_error != net::OK) {
423     ScheduleStartNewClient();
424     return;
425   }
426 
427   // Initialization is done, no need to delay tasks.
428   need_delay_mdns_tasks_ = false;
429   for (Proxy& observer : proxies_)
430     observer.OnNewMdnsReady();
431 }
432 
OnBeforeMdnsDestroy()433 void ServiceDiscoveryClientMdns::OnBeforeMdnsDestroy() {
434   need_delay_mdns_tasks_ = true;
435   weak_ptr_factory_.InvalidateWeakPtrs();
436   for (Proxy& observer : proxies_)
437     observer.OnMdnsDestroy();
438 }
439 
DestroyMdns()440 void ServiceDiscoveryClientMdns::DestroyMdns() {
441   OnBeforeMdnsDestroy();
442   // After calling Proxy::OnMdnsDestroy(), all references to |client_| and
443   // |mdns_| should be destroyed.
444   if (client_)
445     mdns_runner_->DeleteSoon(FROM_HERE, client_.release());
446   if (mdns_)
447     mdns_runner_->DeleteSoon(FROM_HERE, mdns_.release());
448 }
449 
450 }  // namespace local_discovery
451