1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <glog/logging.h>
25 
// Tag type used to select counted_ptr's constructor that wraps a raw pointer
// and takes one additional reference to it.
struct counted_shared_tag {
};
// Atomic reference counter that make_counted places directly in front of the
// object it allocates. The counter member is public.
template <template <typename> class Atom = std::atomic>
struct intrusive_shared_count {
  // A newly constructed counter holds zero references.
  intrusive_shared_count() {
    counts.store(0);
  }

  // Atomically adds `count` references (one by default).
  void add_ref(uint64_t count = 1) {
    counts.fetch_add(count);
  }

  // Atomically removes `count` references (one by default) and returns the
  // value held *before* the subtraction; a result equal to `count` therefore
  // means the counter has just dropped to zero.
  uint64_t release_ref(uint64_t count = 1) {
    const uint64_t before = counts.fetch_sub(count);
    return before;
  }

  Atom<uint64_t> counts;
};
35 
36 template <template <typename> class Atom = std::atomic>
37 struct counted_ptr_base {
38  protected:
getRefcounted_ptr_base39   static intrusive_shared_count<Atom>* getRef(void* pt) {
40     char* p = (char*)pt;
41     p -= sizeof(intrusive_shared_count<Atom>);
42     return (intrusive_shared_count<Atom>*)p;
43   }
44 };
45 
46 // basically shared_ptr, but only supports make_counted, and provides
47 // access to add_ref / release_ref with a count.  Alias not supported.
48 template <typename T, template <typename> class Atom = std::atomic>
49 class counted_ptr : public counted_ptr_base<Atom> {
50  public:
51   T* p_;
counted_ptr()52   counted_ptr() : p_(nullptr) {}
counted_ptr(counted_shared_tag,T * p)53   counted_ptr(counted_shared_tag, T* p) : p_(p) {
54     if (p_) {
55       counted_ptr_base<Atom>::getRef(p_)->add_ref();
56     }
57   }
58 
counted_ptr(const counted_ptr & o)59   counted_ptr(const counted_ptr& o) : p_(o.p_) {
60     if (p_) {
61       counted_ptr_base<Atom>::getRef(p_)->add_ref();
62     }
63   }
64   counted_ptr& operator=(const counted_ptr& o) {
65     if (p_ && counted_ptr_base<Atom>::getRef(p_)->release_ref() == 1) {
66       p_->~T();
67       free(counted_ptr_base<Atom>::getRef(p_));
68     }
69     p_ = o.p_;
70     if (p_) {
71       counted_ptr_base<Atom>::getRef(p_)->add_ref();
72     }
73     return *this;
74   }
counted_ptr(T * p)75   explicit counted_ptr(T* p) : p_(p) { CHECK(!p); }
~counted_ptr()76   ~counted_ptr() {
77     if (p_ && counted_ptr_base<Atom>::getRef(p_)->release_ref() == 1) {
78       p_->~T();
79       free(counted_ptr_base<Atom>::getRef(p_));
80     }
81   }
82   typename std::add_lvalue_reference<T>::type operator*() const { return *p_; }
83 
get()84   T* get() const { return p_; }
85   T* operator->() const { return p_; }
86   explicit operator bool() const { return p_ == nullptr ? false : true; }
87   bool operator==(const counted_ptr<T, Atom>& p) const {
88     return get() == p.get();
89   }
90 };
91 
92 template <
93     template <typename> class Atom = std::atomic,
94     typename T,
95     typename... Args>
make_counted(Args &&...args)96 counted_ptr<T, Atom> make_counted(Args&&... args) {
97   char* mem = (char*)malloc(sizeof(T) + sizeof(intrusive_shared_count<Atom>));
98   if (!mem) {
99     throw std::bad_alloc();
100   }
101   new (mem) intrusive_shared_count<Atom>();
102   T* ptr = (T*)(mem + sizeof(intrusive_shared_count<Atom>));
103   new (ptr) T(std::forward<Args>(args)...);
104   return counted_ptr<T, Atom>(counted_shared_tag(), ptr);
105 }
106 
107 template <template <typename> class Atom = std::atomic>
108 class counted_ptr_internals : public counted_ptr_base<Atom> {
109  public:
110   template <typename T, typename... Args>
make_ptr(Args &&...args)111   static counted_ptr<T, Atom> make_ptr(Args&&... args) {
112     return make_counted<Atom, T>(std::forward<Args...>(args...));
113   }
114   template <typename T>
115   using CountedPtr = counted_ptr<T, Atom>;
116   typedef void counted_base;
117 
118   template <typename T>
get_counted_base(const counted_ptr<T,Atom> & bar)119   static counted_base* get_counted_base(const counted_ptr<T, Atom>& bar) {
120     return bar.p_;
121   }
122 
123   template <typename T>
get_shared_ptr(counted_base * base)124   static T* get_shared_ptr(counted_base* base) {
125     return (T*)base;
126   }
127 
128   template <typename T>
release_ptr(counted_ptr<T,Atom> & p)129   static T* release_ptr(counted_ptr<T, Atom>& p) {
130     auto res = p.p_;
131     p.p_ = nullptr;
132     return res;
133   }
134 
135   template <typename T>
136   static counted_ptr<T, Atom> get_shared_ptr_from_counted_base(
137       counted_base* base, bool inc = true) {
138     auto res = counted_ptr<T, Atom>(counted_shared_tag(), (T*)(base));
139     if (!inc) {
140       release_shared<T>(base, 1);
141     }
142     return res;
143   }
144 
inc_shared_count(counted_base * base,int64_t count)145   static void inc_shared_count(counted_base* base, int64_t count) {
146     counted_ptr_base<Atom>::getRef(base)->add_ref(count);
147   }
148 
149   template <typename T>
release_shared(counted_base * base,uint64_t count)150   static void release_shared(counted_base* base, uint64_t count) {
151     if (count == counted_ptr_base<Atom>::getRef(base)->release_ref(count)) {
152       ((T*)base)->~T();
153       free(counted_ptr_base<Atom>::getRef(base));
154     }
155   }
156 };
157