1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_DNS_MDNS_CLIENT_IMPL_H_ 6 #define NET_DNS_MDNS_CLIENT_IMPL_H_ 7 8 #include <stdint.h> 9 10 #include <map> 11 #include <memory> 12 #include <string> 13 #include <utility> 14 #include <vector> 15 16 #include "base/cancelable_callback.h" 17 #include "base/containers/queue.h" 18 #include "base/gtest_prod_util.h" 19 #include "base/macros.h" 20 #include "base/observer_list.h" 21 #include "net/base/io_buffer.h" 22 #include "net/base/ip_endpoint.h" 23 #include "net/base/net_export.h" 24 #include "net/dns/mdns_cache.h" 25 #include "net/dns/mdns_client.h" 26 #include "net/socket/datagram_server_socket.h" 27 #include "net/socket/udp_server_socket.h" 28 #include "net/socket/udp_socket.h" 29 30 namespace base { 31 class Clock; 32 class OneShotTimer; 33 } // namespace base 34 35 namespace net { 36 37 class NetLog; 38 39 class MDnsSocketFactoryImpl : public MDnsSocketFactory { 40 public: MDnsSocketFactoryImpl()41 MDnsSocketFactoryImpl() : net_log_(nullptr) {} MDnsSocketFactoryImpl(NetLog * net_log)42 explicit MDnsSocketFactoryImpl(NetLog* net_log) : net_log_(net_log) {} ~MDnsSocketFactoryImpl()43 ~MDnsSocketFactoryImpl() override {} 44 45 void CreateSockets( 46 std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override; 47 48 private: 49 NetLog* const net_log_; 50 51 DISALLOW_COPY_AND_ASSIGN(MDnsSocketFactoryImpl); 52 }; 53 54 // A connection to the network for multicast DNS clients. It reads data into 55 // DnsResponse objects and alerts the delegate that a packet has been received. 56 class NET_EXPORT_PRIVATE MDnsConnection { 57 public: 58 class Delegate { 59 public: 60 // Handle an mDNS packet buffered in |response| with a size of |bytes_read|. 61 virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0; 62 virtual void OnConnectionError(int error) = 0; ~Delegate()63 virtual ~Delegate() {} 64 }; 65 66 explicit MDnsConnection(MDnsConnection::Delegate* delegate); 67 virtual ~MDnsConnection(); 68 69 // Succeeds if at least one of the socket handlers succeeded. 70 int Init(MDnsSocketFactory* socket_factory); 71 void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size); 72 73 private: 74 class SocketHandler { 75 public: 76 SocketHandler(std::unique_ptr<DatagramServerSocket> socket, 77 MDnsConnection* connection); 78 ~SocketHandler(); 79 80 int Start(); 81 void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size); 82 83 private: 84 int DoLoop(int rv); 85 void OnDatagramReceived(int rv); 86 87 // Callback for when sending a query has finished. 88 void SendDone(int rv); 89 90 std::unique_ptr<DatagramServerSocket> socket_; 91 MDnsConnection* connection_; 92 IPEndPoint recv_addr_; 93 DnsResponse response_; 94 IPEndPoint multicast_addr_; 95 bool send_in_progress_; 96 base::queue<std::pair<scoped_refptr<IOBuffer>, unsigned>> send_queue_; 97 98 DISALLOW_COPY_AND_ASSIGN(SocketHandler); 99 }; 100 101 // Callback for handling a datagram being received on either ipv4 or ipv6. 102 void OnDatagramReceived(DnsResponse* response, 103 const IPEndPoint& recv_addr, 104 int bytes_read); 105 106 void PostOnError(SocketHandler* loop, int rv); 107 void OnError(int rv); 108 109 // Only socket handlers which successfully bound and started are kept. 110 std::vector<std::unique_ptr<SocketHandler>> socket_handlers_; 111 112 Delegate* delegate_; 113 114 base::WeakPtrFactory<MDnsConnection> weak_ptr_factory_{this}; 115 116 DISALLOW_COPY_AND_ASSIGN(MDnsConnection); 117 }; 118 119 class MDnsListenerImpl; 120 121 class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient { 122 public: 123 // The core object exists while the MDnsClient is listening, and is deleted 124 // whenever the number of listeners reaches zero. The deletion happens 125 // asychronously, so destroying the last listener does not immediately 126 // invalidate the core. 127 class Core : public base::SupportsWeakPtr<Core>, MDnsConnection::Delegate { 128 public: 129 Core(base::Clock* clock, base::OneShotTimer* timer); 130 ~Core() override; 131 132 // Initialize the core. 133 int Init(MDnsSocketFactory* socket_factory); 134 135 // Send a query with a specific rrtype and name. Returns true on success. 136 bool SendQuery(uint16_t rrtype, const std::string& name); 137 138 // Add/remove a listener to the list of listeners. 139 void AddListener(MDnsListenerImpl* listener); 140 void RemoveListener(MDnsListenerImpl* listener); 141 142 // Query the cache for records of a specific type and name. 143 void QueryCache(uint16_t rrtype, 144 const std::string& name, 145 std::vector<const RecordParsed*>* records) const; 146 147 // Parse the response and alert relevant listeners. 148 void HandlePacket(DnsResponse* response, int bytes_read) override; 149 150 void OnConnectionError(int error) override; 151 cache_for_testing()152 MDnsCache* cache_for_testing() { return &cache_; } 153 154 private: 155 FRIEND_TEST_ALL_PREFIXES(MDnsTest, CacheCleanupWithShortTTL); 156 157 class ListenerKey { 158 public: 159 ListenerKey(const std::string& name, uint16_t type); 160 ListenerKey(const ListenerKey&) = default; 161 ListenerKey(ListenerKey&&) = default; 162 bool operator<(const ListenerKey& key) const; name_lowercase()163 const std::string& name_lowercase() const { return name_lowercase_; } type()164 uint16_t type() const { return type_; } 165 166 private: 167 std::string name_lowercase_; 168 uint16_t type_; 169 }; 170 typedef base::ObserverList<MDnsListenerImpl>::Unchecked ObserverListType; 171 typedef std::map<ListenerKey, std::unique_ptr<ObserverListType>> 172 ListenerMap; 173 174 // Alert listeners of an update to the cache. 175 void AlertListeners(MDnsCache::UpdateType update_type, 176 const ListenerKey& key, const RecordParsed* record); 177 178 // Schedule a cache cleanup to a specific time, cancelling other cleanups. 179 void ScheduleCleanup(base::Time cleanup); 180 181 // Clean up the cache and schedule a new cleanup. 182 void DoCleanup(); 183 184 // Callback for when a record is removed from the cache. 185 void OnRecordRemoved(const RecordParsed* record); 186 187 void NotifyNsecRecord(const RecordParsed* record); 188 189 // Delete and erase the observer list for |key|. Only deletes the observer 190 // list if is empty. 191 void CleanupObserverList(const ListenerKey& key); 192 193 ListenerMap listeners_; 194 195 MDnsCache cache_; 196 197 base::Clock* clock_; 198 base::OneShotTimer* cleanup_timer_; 199 base::Time scheduled_cleanup_; 200 201 std::unique_ptr<MDnsConnection> connection_; 202 203 DISALLOW_COPY_AND_ASSIGN(Core); 204 }; 205 206 MDnsClientImpl(); 207 208 // Test constructor, takes a mock clock and mock timer. 209 MDnsClientImpl(base::Clock* clock, 210 std::unique_ptr<base::OneShotTimer> cleanup_timer); 211 212 ~MDnsClientImpl() override; 213 214 // MDnsClient implementation: 215 std::unique_ptr<MDnsListener> CreateListener( 216 uint16_t rrtype, 217 const std::string& name, 218 MDnsListener::Delegate* delegate) override; 219 220 std::unique_ptr<MDnsTransaction> CreateTransaction( 221 uint16_t rrtype, 222 const std::string& name, 223 int flags, 224 const MDnsTransaction::ResultCallback& callback) override; 225 226 int StartListening(MDnsSocketFactory* socket_factory) override; 227 void StopListening() override; 228 bool IsListening() const override; 229 core()230 Core* core() { return core_.get(); } 231 232 private: 233 base::Clock* clock_; 234 std::unique_ptr<base::OneShotTimer> cleanup_timer_; 235 236 std::unique_ptr<Core> core_; 237 238 DISALLOW_COPY_AND_ASSIGN(MDnsClientImpl); 239 }; 240 241 class MDnsListenerImpl : public MDnsListener, 242 public base::SupportsWeakPtr<MDnsListenerImpl> { 243 public: 244 MDnsListenerImpl(uint16_t rrtype, 245 const std::string& name, 246 base::Clock* clock, 247 MDnsListener::Delegate* delegate, 248 MDnsClientImpl* client); 249 250 ~MDnsListenerImpl() override; 251 252 // MDnsListener implementation: 253 bool Start() override; 254 255 // Actively refresh any received records. 256 void SetActiveRefresh(bool active_refresh) override; 257 258 const std::string& GetName() const override; 259 260 uint16_t GetType() const override; 261 delegate()262 MDnsListener::Delegate* delegate() { return delegate_; } 263 264 // Alert the delegate of a record update. 265 void HandleRecordUpdate(MDnsCache::UpdateType update_type, 266 const RecordParsed* record_parsed); 267 268 // Alert the delegate of the existence of an Nsec record. 269 void AlertNsecRecord(); 270 271 private: 272 void ScheduleNextRefresh(); 273 void DoRefresh(); 274 275 uint16_t rrtype_; 276 std::string name_; 277 base::Clock* clock_; 278 MDnsClientImpl* client_; 279 MDnsListener::Delegate* delegate_; 280 281 base::Time last_update_; 282 uint32_t ttl_; 283 bool started_; 284 bool active_refresh_; 285 286 base::CancelableClosure next_refresh_; 287 DISALLOW_COPY_AND_ASSIGN(MDnsListenerImpl); 288 }; 289 290 class MDnsTransactionImpl : public base::SupportsWeakPtr<MDnsTransactionImpl>, 291 public MDnsTransaction, 292 public MDnsListener::Delegate { 293 public: 294 MDnsTransactionImpl(uint16_t rrtype, 295 const std::string& name, 296 int flags, 297 const MDnsTransaction::ResultCallback& callback, 298 MDnsClientImpl* client); 299 ~MDnsTransactionImpl() override; 300 301 // MDnsTransaction implementation: 302 bool Start() override; 303 304 const std::string& GetName() const override; 305 uint16_t GetType() const override; 306 307 // MDnsListener::Delegate implementation: 308 void OnRecordUpdate(MDnsListener::UpdateType update, 309 const RecordParsed* record) override; 310 void OnNsecRecord(const std::string& name, unsigned type) override; 311 312 void OnCachePurged() override; 313 314 private: is_active()315 bool is_active() { return !callback_.is_null(); } 316 317 void Reset(); 318 319 // Trigger the callback and reset all related variables. 320 void TriggerCallback(MDnsTransaction::Result result, 321 const RecordParsed* record); 322 323 // Internal callback for when a cache record is found. 324 void CacheRecordFound(const RecordParsed* record); 325 326 // Signal the transactionis over and release all related resources. 327 void SignalTransactionOver(); 328 329 // Reads records from the cache and calls the callback for every 330 // record read. 331 void ServeRecordsFromCache(); 332 333 // Send a query to the network and set up a timeout to time out the 334 // transaction. Returns false if it fails to start listening on the network 335 // or if it fails to send a query. 336 bool QueryAndListen(); 337 338 uint16_t rrtype_; 339 std::string name_; 340 MDnsTransaction::ResultCallback callback_; 341 342 std::unique_ptr<MDnsListener> listener_; 343 base::CancelableOnceCallback<void()> timeout_; 344 345 MDnsClientImpl* client_; 346 347 bool started_; 348 int flags_; 349 350 DISALLOW_COPY_AND_ASSIGN(MDnsTransactionImpl); 351 }; 352 353 } // namespace net 354 #endif // NET_DNS_MDNS_CLIENT_IMPL_H_ 355