1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <limits>
6 
7 #include "base/atomic_ref_count.h"
8 #include "base/no_destructor.h"
9 #include "net/base/features.h"
10 #include "net/socket/udp_socket_global_limits.h"
11 
12 namespace net {
13 
14 namespace {
15 
16 // Threadsafe singleton for tracking the process-wide count of UDP sockets.
17 class GlobalUDPSocketCounts {
18  public:
GlobalUDPSocketCounts()19   GlobalUDPSocketCounts() : count_(0) {}
20 
21   ~GlobalUDPSocketCounts() = delete;
22 
Get()23   static GlobalUDPSocketCounts& Get() {
24     static base::NoDestructor<GlobalUDPSocketCounts> singleton;
25     return *singleton;
26   }
27 
TryAcquireSocket()28   bool TryAcquireSocket() WARN_UNUSED_RESULT {
29     int previous = count_.Increment(1);
30     if (previous >= GetMax()) {
31       count_.Increment(-1);
32       return false;
33     }
34 
35     return true;
36   }
37 
GetMax()38   int GetMax() {
39     if (base::FeatureList::IsEnabled(features::kLimitOpenUDPSockets))
40       return features::kLimitOpenUDPSocketsMax.Get();
41 
42     return std::numeric_limits<int>::max();
43   }
44 
ReleaseSocket()45   void ReleaseSocket() { count_.Increment(-1); }
46 
GetCountForTesting()47   int GetCountForTesting() { return count_.SubtleRefCountForDebug(); }
48 
49  private:
50   base::AtomicRefCount count_;
51 };
52 
53 }  // namespace
54 
OwnedUDPSocketCount()55 OwnedUDPSocketCount::OwnedUDPSocketCount() : OwnedUDPSocketCount(true) {}
56 
OwnedUDPSocketCount(OwnedUDPSocketCount && other)57 OwnedUDPSocketCount::OwnedUDPSocketCount(OwnedUDPSocketCount&& other) {
58   *this = std::move(other);
59 }
60 
operator =(OwnedUDPSocketCount && other)61 OwnedUDPSocketCount& OwnedUDPSocketCount::operator=(
62     OwnedUDPSocketCount&& other) {
63   Reset();
64   empty_ = other.empty_;
65   other.empty_ = true;
66   return *this;
67 }
68 
~OwnedUDPSocketCount()69 OwnedUDPSocketCount::~OwnedUDPSocketCount() {
70   Reset();
71 }
72 
Reset()73 void OwnedUDPSocketCount::Reset() {
74   if (!empty_) {
75     GlobalUDPSocketCounts::Get().ReleaseSocket();
76     empty_ = true;
77   }
78 }
79 
OwnedUDPSocketCount(bool empty)80 OwnedUDPSocketCount::OwnedUDPSocketCount(bool empty) : empty_(empty) {}
81 
TryAcquireGlobalUDPSocketCount()82 OwnedUDPSocketCount TryAcquireGlobalUDPSocketCount() {
83   bool success = GlobalUDPSocketCounts::Get().TryAcquireSocket();
84   return OwnedUDPSocketCount(!success);
85 }
86 
GetGlobalUDPSocketCountForTesting()87 int GetGlobalUDPSocketCountForTesting() {
88   return GlobalUDPSocketCounts::Get().GetCountForTesting();
89 }
90 
91 }  // namespace net
92 
93