1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/local_discovery/service_discovery_client_mdns.h"
6
7 #include <stddef.h>
8
9 #include <utility>
10 #include <vector>
11
12 #include "base/bind.h"
13 #include "base/location.h"
14 #include "base/macros.h"
15 #include "base/task/thread_pool.h"
16 #include "base/threading/thread_task_runner_handle.h"
17 #include "chrome/browser/browser_process.h"
18 #include "chrome/browser/local_discovery/service_discovery_client_impl.h"
19 #include "content/public/browser/browser_task_traits.h"
20 #include "content/public/browser/browser_thread.h"
21 #include "content/public/browser/network_service_instance.h"
22 #include "net/base/net_errors.h"
23 #include "net/socket/datagram_server_socket.h"
24
25 namespace net {
26 class IPAddress;
27 }
28
29 namespace local_discovery {
30
31 using content::BrowserThread;
32
33 // Base class for objects returned by ServiceDiscoveryClient implementation.
34 // Handles interaction of client code on UI thread end net code on mdns thread.
35 class ServiceDiscoveryClientMdns::Proxy {
36 public:
37 using WeakPtr = base::WeakPtr<Proxy>;
38
Proxy(ServiceDiscoveryClientMdns * client)39 explicit Proxy(ServiceDiscoveryClientMdns* client) : client_(client) {
40 DCHECK_CURRENTLY_ON(BrowserThread::UI);
41 client_->proxies_.AddObserver(this);
42 }
43
~Proxy()44 virtual ~Proxy() {
45 DCHECK_CURRENTLY_ON(BrowserThread::UI);
46 client_->proxies_.RemoveObserver(this);
47 }
48
49 // Returns true if object is not yet shutdown.
50 virtual bool IsValid() = 0;
51
52 // Notifies proxies that mDNS layer is going to be destroyed.
53 virtual void OnMdnsDestroy() = 0;
54
55 // Notifies proxies that new mDNS instance is ready.
OnNewMdnsReady()56 virtual void OnNewMdnsReady() {
57 DCHECK(!client_->need_delay_mdns_tasks_);
58 if (IsValid()) {
59 for (auto& task : delayed_tasks_)
60 client_->mdns_runner_->PostTask(FROM_HERE, std::move(task));
61 }
62 delayed_tasks_.clear();
63 }
64
65 // Runs callback using this method to abort callback if instance of |Proxy|
66 // is deleted.
RunCallback(base::OnceClosure callback)67 void RunCallback(base::OnceClosure callback) {
68 DCHECK_CURRENTLY_ON(BrowserThread::UI);
69 std::move(callback).Run();
70 }
71
72 protected:
PostToMdnsThread(base::OnceClosure task)73 void PostToMdnsThread(base::OnceClosure task) {
74 DCHECK(IsValid());
75 // The first task on the IO thread for each |mdns_| instance must be
76 // InitMdns(). OnInterfaceListReady() could be delayed by
77 // GetMDnsInterfacesToBind() running on a background task runner, so
78 // PostToMdnsThread() could be called to post task for |mdns_| that is not
79 // initialized yet.
80 if (!client_->need_delay_mdns_tasks_) {
81 client_->mdns_runner_->PostTask(FROM_HERE, std::move(task));
82 return;
83 }
84 delayed_tasks_.emplace_back(std::move(task));
85 }
86
PostToUIThread(base::OnceClosure task)87 static bool PostToUIThread(base::OnceClosure task) {
88 return content::GetUIThreadTaskRunner({})->PostTask(FROM_HERE,
89 std::move(task));
90 }
91
client()92 ServiceDiscoveryClient* client() {
93 return client_->client_.get();
94 }
95
GetWeakPtr()96 WeakPtr GetWeakPtr() {
97 return weak_ptr_factory_.GetWeakPtr();
98 }
99
100 template<class T>
DeleteOnMdnsThread(T * t)101 void DeleteOnMdnsThread(T* t) {
102 if (!t)
103 return;
104 if (!client_->mdns_runner_->DeleteSoon(FROM_HERE, t))
105 delete t;
106 }
107
108 private:
109 scoped_refptr<ServiceDiscoveryClientMdns> client_;
110 // Delayed |mdns_runner_| tasks.
111 std::vector<base::OnceClosure> delayed_tasks_;
112 base::WeakPtrFactory<Proxy> weak_ptr_factory_{this};
113
114 DISALLOW_COPY_AND_ASSIGN(Proxy);
115 };
116
117 namespace {
118
119 const int kMaxRestartAttempts = 10;
120 const int kRestartDelayOnNetworkChangeSeconds = 3;
121
122 using MdnsInitCallback = base::OnceCallback<void(int)>;
123
124 class SocketFactory : public net::MDnsSocketFactory {
125 public:
SocketFactory(const net::InterfaceIndexFamilyList & interfaces)126 explicit SocketFactory(const net::InterfaceIndexFamilyList& interfaces)
127 : interfaces_(interfaces) {}
128
129 // net::MDnsSocketFactory implementation:
CreateSockets(std::vector<std::unique_ptr<net::DatagramServerSocket>> * sockets)130 void CreateSockets(std::vector<std::unique_ptr<net::DatagramServerSocket>>*
131 sockets) override {
132 for (size_t i = 0; i < interfaces_.size(); ++i) {
133 DCHECK(interfaces_[i].second == net::ADDRESS_FAMILY_IPV4 ||
134 interfaces_[i].second == net::ADDRESS_FAMILY_IPV6);
135 std::unique_ptr<net::DatagramServerSocket> socket(CreateAndBindMDnsSocket(
136 interfaces_[i].second, interfaces_[i].first, nullptr /* net_log */));
137 if (socket)
138 sockets->push_back(std::move(socket));
139 }
140 }
141
142 private:
143 net::InterfaceIndexFamilyList interfaces_;
144
145 DISALLOW_COPY_AND_ASSIGN(SocketFactory);
146 };
147
InitMdns(MdnsInitCallback on_initialized,const net::InterfaceIndexFamilyList & interfaces,net::MDnsClient * mdns)148 void InitMdns(MdnsInitCallback on_initialized,
149 const net::InterfaceIndexFamilyList& interfaces,
150 net::MDnsClient* mdns) {
151 SocketFactory socket_factory(interfaces);
152 content::GetUIThreadTaskRunner({})->PostTask(
153 FROM_HERE, base::BindOnce(std::move(on_initialized),
154 mdns->StartListening(&socket_factory)));
155 }
156
157 template<class T>
158 class ProxyBase : public ServiceDiscoveryClientMdns::Proxy, public T {
159 public:
160 using Base = ProxyBase<T>;
161
ProxyBase(ServiceDiscoveryClientMdns * client)162 explicit ProxyBase(ServiceDiscoveryClientMdns* client)
163 : Proxy(client) {
164 }
165
~ProxyBase()166 ~ProxyBase() override {
167 DeleteOnMdnsThread(implementation_.release());
168 }
169
IsValid()170 bool IsValid() override {
171 return !!implementation();
172 }
173
OnMdnsDestroy()174 void OnMdnsDestroy() override {
175 DeleteOnMdnsThread(implementation_.release());
176 }
177
178 protected:
set_implementation(std::unique_ptr<T> implementation)179 void set_implementation(std::unique_ptr<T> implementation) {
180 implementation_ = std::move(implementation);
181 }
182
implementation() const183 T* implementation() const {
184 return implementation_.get();
185 }
186
187 private:
188 std::unique_ptr<T> implementation_;
189
190 DISALLOW_COPY_AND_ASSIGN(ProxyBase);
191 };
192
193 class ServiceWatcherProxy : public ProxyBase<ServiceWatcher> {
194 public:
ServiceWatcherProxy(ServiceDiscoveryClientMdns * client_mdns,const std::string & service_type,ServiceWatcher::UpdatedCallback callback)195 ServiceWatcherProxy(ServiceDiscoveryClientMdns* client_mdns,
196 const std::string& service_type,
197 ServiceWatcher::UpdatedCallback callback)
198 : ProxyBase(client_mdns),
199 service_type_(service_type),
200 callback_(callback) {
201 // It's safe to call |CreateServiceWatcher| on UI thread, because
202 // |MDnsClient| is not used there. It's simplify implementation.
203 set_implementation(client()->CreateServiceWatcher(
204 service_type, base::BindRepeating(&ServiceWatcherProxy::OnCallback,
205 GetWeakPtr(), std::move(callback))));
206 }
207
208 // ServiceWatcher methods.
Start()209 void Start() override {
210 if (implementation()) {
211 PostToMdnsThread(base::BindOnce(&ServiceWatcher::Start,
212 base::Unretained(implementation())));
213 }
214 }
215
DiscoverNewServices()216 void DiscoverNewServices() override {
217 if (implementation()) {
218 PostToMdnsThread(base::BindOnce(&ServiceWatcher::DiscoverNewServices,
219 base::Unretained(implementation())));
220 }
221 }
222
SetActivelyRefreshServices(bool actively_refresh_services)223 void SetActivelyRefreshServices(bool actively_refresh_services) override {
224 if (implementation()) {
225 PostToMdnsThread(base::BindOnce(
226 &ServiceWatcher::SetActivelyRefreshServices,
227 base::Unretained(implementation()), actively_refresh_services));
228 }
229 }
230
GetServiceType() const231 std::string GetServiceType() const override { return service_type_; }
232
OnNewMdnsReady()233 void OnNewMdnsReady() override {
234 ProxyBase<ServiceWatcher>::OnNewMdnsReady();
235 if (!implementation())
236 callback_.Run(ServiceWatcher::UPDATE_INVALIDATED, "");
237 }
238
239 private:
OnCallback(const WeakPtr & proxy,ServiceWatcher::UpdatedCallback callback,UpdateType a1,const std::string & a2)240 static void OnCallback(const WeakPtr& proxy,
241 ServiceWatcher::UpdatedCallback callback,
242 UpdateType a1,
243 const std::string& a2) {
244 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
245 PostToUIThread(base::BindOnce(&Base::RunCallback, proxy,
246 base::BindOnce(std::move(callback), a1, a2)));
247 }
248 std::string service_type_;
249 ServiceWatcher::UpdatedCallback callback_;
250
251 DISALLOW_COPY_AND_ASSIGN(ServiceWatcherProxy);
252 };
253
254 class ServiceResolverProxy : public ProxyBase<ServiceResolver> {
255 public:
ServiceResolverProxy(ServiceDiscoveryClientMdns * client_mdns,const std::string & service_name,ServiceResolver::ResolveCompleteCallback callback)256 ServiceResolverProxy(ServiceDiscoveryClientMdns* client_mdns,
257 const std::string& service_name,
258 ServiceResolver::ResolveCompleteCallback callback)
259 : ProxyBase(client_mdns), service_name_(service_name) {
260 // It's safe to call |CreateServiceResolver| on UI thread, because
261 // |MDnsClient| is not used there. It's simplify implementation.
262 set_implementation(client()->CreateServiceResolver(
263 service_name, base::BindOnce(&ServiceResolverProxy::OnCallback,
264 GetWeakPtr(), std::move(callback))));
265 }
266
267 // ServiceResolver methods.
StartResolving()268 void StartResolving() override {
269 if (implementation()) {
270 PostToMdnsThread(base::BindOnce(&ServiceResolver::StartResolving,
271 base::Unretained(implementation())));
272 }
273 }
274
GetName() const275 std::string GetName() const override { return service_name_; }
276
277 private:
OnCallback(const WeakPtr & proxy,ServiceResolver::ResolveCompleteCallback callback,RequestStatus a1,const ServiceDescription & a2)278 static void OnCallback(const WeakPtr& proxy,
279 ServiceResolver::ResolveCompleteCallback callback,
280 RequestStatus a1,
281 const ServiceDescription& a2) {
282 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
283 PostToUIThread(base::BindOnce(&Base::RunCallback, proxy,
284 base::BindOnce(std::move(callback), a1, a2)));
285 }
286
287 std::string service_name_;
288
289 DISALLOW_COPY_AND_ASSIGN(ServiceResolverProxy);
290 };
291
292 class LocalDomainResolverProxy : public ProxyBase<LocalDomainResolver> {
293 public:
LocalDomainResolverProxy(ServiceDiscoveryClientMdns * client_mdns,const std::string & domain,net::AddressFamily address_family,LocalDomainResolver::IPAddressCallback callback)294 LocalDomainResolverProxy(ServiceDiscoveryClientMdns* client_mdns,
295 const std::string& domain,
296 net::AddressFamily address_family,
297 LocalDomainResolver::IPAddressCallback callback)
298 : ProxyBase(client_mdns) {
299 // It's safe to call |CreateLocalDomainResolver| on UI thread, because
300 // |MDnsClient| is not used there. It's simplify implementation.
301 set_implementation(client()->CreateLocalDomainResolver(
302 domain, address_family,
303 base::BindOnce(&LocalDomainResolverProxy::OnCallback, GetWeakPtr(),
304 std::move(callback))));
305 }
306
307 // LocalDomainResolver methods.
Start()308 void Start() override {
309 if (implementation()) {
310 PostToMdnsThread(base::BindOnce(&LocalDomainResolver::Start,
311 base::Unretained(implementation())));
312 }
313 }
314
315 private:
OnCallback(const WeakPtr & proxy,LocalDomainResolver::IPAddressCallback callback,bool a1,const net::IPAddress & a2,const net::IPAddress & a3)316 static void OnCallback(const WeakPtr& proxy,
317 LocalDomainResolver::IPAddressCallback callback,
318 bool a1,
319 const net::IPAddress& a2,
320 const net::IPAddress& a3) {
321 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
322 PostToUIThread(
323 base::BindOnce(&Base::RunCallback, proxy,
324 base::BindOnce(std::move(callback), a1, a2, a3)));
325 }
326
327 DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverProxy);
328 };
329
330 } // namespace
331
ServiceDiscoveryClientMdns()332 ServiceDiscoveryClientMdns::ServiceDiscoveryClientMdns()
333 : mdns_runner_(content::GetIOThreadTaskRunner({})) {
334 DCHECK_CURRENTLY_ON(BrowserThread::UI);
335 content::GetNetworkConnectionTracker()->AddNetworkConnectionObserver(this);
336 StartNewClient();
337 }
338
339 std::unique_ptr<ServiceWatcher>
CreateServiceWatcher(const std::string & service_type,ServiceWatcher::UpdatedCallback callback)340 ServiceDiscoveryClientMdns::CreateServiceWatcher(
341 const std::string& service_type,
342 ServiceWatcher::UpdatedCallback callback) {
343 DCHECK_CURRENTLY_ON(BrowserThread::UI);
344 return std::make_unique<ServiceWatcherProxy>(this, service_type,
345 std::move(callback));
346 }
347
348 std::unique_ptr<ServiceResolver>
CreateServiceResolver(const std::string & service_name,ServiceResolver::ResolveCompleteCallback callback)349 ServiceDiscoveryClientMdns::CreateServiceResolver(
350 const std::string& service_name,
351 ServiceResolver::ResolveCompleteCallback callback) {
352 DCHECK_CURRENTLY_ON(BrowserThread::UI);
353 return std::make_unique<ServiceResolverProxy>(this, service_name,
354 std::move(callback));
355 }
356
357 std::unique_ptr<LocalDomainResolver>
CreateLocalDomainResolver(const std::string & domain,net::AddressFamily address_family,LocalDomainResolver::IPAddressCallback callback)358 ServiceDiscoveryClientMdns::CreateLocalDomainResolver(
359 const std::string& domain,
360 net::AddressFamily address_family,
361 LocalDomainResolver::IPAddressCallback callback) {
362 DCHECK_CURRENTLY_ON(BrowserThread::UI);
363 return std::make_unique<LocalDomainResolverProxy>(
364 this, domain, address_family, std::move(callback));
365 }
366
~ServiceDiscoveryClientMdns()367 ServiceDiscoveryClientMdns::~ServiceDiscoveryClientMdns() {
368 DCHECK_CURRENTLY_ON(BrowserThread::UI);
369 content::GetNetworkConnectionTracker()->RemoveNetworkConnectionObserver(this);
370 DestroyMdns();
371 }
372
OnConnectionChanged(network::mojom::ConnectionType type)373 void ServiceDiscoveryClientMdns::OnConnectionChanged(
374 network::mojom::ConnectionType type) {
375 DCHECK_CURRENTLY_ON(BrowserThread::UI);
376 // Only network changes resets counter.
377 restart_attempts_ = 0;
378 ScheduleStartNewClient();
379 }
380
ScheduleStartNewClient()381 void ServiceDiscoveryClientMdns::ScheduleStartNewClient() {
382 DCHECK_CURRENTLY_ON(BrowserThread::UI);
383 OnBeforeMdnsDestroy();
384 if (restart_attempts_ >= kMaxRestartAttempts)
385 return;
386
387 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
388 FROM_HERE,
389 base::BindOnce(&ServiceDiscoveryClientMdns::StartNewClient,
390 weak_ptr_factory_.GetWeakPtr()),
391 base::TimeDelta::FromSeconds(kRestartDelayOnNetworkChangeSeconds *
392 (1 << restart_attempts_)));
393 }
394
StartNewClient()395 void ServiceDiscoveryClientMdns::StartNewClient() {
396 DCHECK_CURRENTLY_ON(BrowserThread::UI);
397 ++restart_attempts_;
398 DestroyMdns();
399 mdns_ = net::MDnsClient::CreateDefault();
400 client_ = std::make_unique<ServiceDiscoveryClientImpl>(mdns_.get());
401 base::ThreadPool::PostTaskAndReplyWithResult(
402 FROM_HERE, {base::TaskPriority::BEST_EFFORT, base::MayBlock()},
403 base::BindOnce(&net::GetMDnsInterfacesToBind),
404 base::BindOnce(&ServiceDiscoveryClientMdns::OnInterfaceListReady,
405 weak_ptr_factory_.GetWeakPtr()));
406 }
407
OnInterfaceListReady(const net::InterfaceIndexFamilyList & interfaces)408 void ServiceDiscoveryClientMdns::OnInterfaceListReady(
409 const net::InterfaceIndexFamilyList& interfaces) {
410 DCHECK_CURRENTLY_ON(BrowserThread::UI);
411 mdns_runner_->PostTask(
412 FROM_HERE,
413 base::BindOnce(
414 &InitMdns,
415 base::BindOnce(&ServiceDiscoveryClientMdns::OnMdnsInitialized,
416 weak_ptr_factory_.GetWeakPtr()),
417 interfaces, base::Unretained(mdns_.get())));
418 }
419
OnMdnsInitialized(int net_error)420 void ServiceDiscoveryClientMdns::OnMdnsInitialized(int net_error) {
421 DCHECK_CURRENTLY_ON(BrowserThread::UI);
422 if (net_error != net::OK) {
423 ScheduleStartNewClient();
424 return;
425 }
426
427 // Initialization is done, no need to delay tasks.
428 need_delay_mdns_tasks_ = false;
429 for (Proxy& observer : proxies_)
430 observer.OnNewMdnsReady();
431 }
432
OnBeforeMdnsDestroy()433 void ServiceDiscoveryClientMdns::OnBeforeMdnsDestroy() {
434 need_delay_mdns_tasks_ = true;
435 weak_ptr_factory_.InvalidateWeakPtrs();
436 for (Proxy& observer : proxies_)
437 observer.OnMdnsDestroy();
438 }
439
DestroyMdns()440 void ServiceDiscoveryClientMdns::DestroyMdns() {
441 OnBeforeMdnsDestroy();
442 // After calling Proxy::OnMdnsDestroy(), all references to |client_| and
443 // |mdns_| should be destroyed.
444 if (client_)
445 mdns_runner_->DeleteSoon(FROM_HERE, client_.release());
446 if (mdns_)
447 mdns_runner_->DeleteSoon(FROM_HERE, mdns_.release());
448 }
449
450 } // namespace local_discovery
451