1 #ifndef MEM_BACKTRACK_CDJOICDSJ
2 #define MEM_BACKTRACK_CDJOICDSJ
3 #include "library/vec1.hpp"
4 #include "library/free_any_object.hpp"
5 #include "library/library.hpp"
6 #include <utility>
7 #include <set>
8 
9 class MemoryBacktracker;
10 
11 class BacktrackableType
12 {
13     MemoryBacktracker* mb;
14 public:
15     BacktrackableType(MemoryBacktracker* _mb);
16     virtual ~BacktrackableType();
17 
event_pushWorld()18     virtual void event_pushWorld(){}
event_popWorld()19     virtual void event_popWorld() {}
20 };
21 
22 
23 template<typename T>
24 class Reverting
25 {
26     MemoryBacktracker* mb;
27     T* val;
28 public:
Reverting(MemoryBacktracker * _mb,T * _t)29     Reverting(MemoryBacktracker* _mb, T* _t)
30     : mb(_mb), val(_t)
31     { }
32 
Reverting()33     Reverting()
34     : mb(NULL), val(NULL)
35     { }
36 
Reverting(const Reverting & bt)37     Reverting(const Reverting& bt):
38     mb(bt.mb), val(bt.val)
39     { }
40 
get() const41     T get() const
42     { return *val; }
43 
44     void set(const T& val);
45     void set(T&& val);
46 };
47 
48 template<typename T>
49 class RevertingStack
50 {
51     MemoryBacktracker* mb;
52     vec1<T>* stack;
53 public:
RevertingStack(MemoryBacktracker * _mb,vec1<T> * _t)54     RevertingStack(MemoryBacktracker* _mb, vec1<T>* _t)
55     : mb(_mb), stack(_t)
56     { }
57 
RevertingStack()58     RevertingStack()
59     : mb(NULL), stack(NULL)
60     { }
61 
RevertingStack(const RevertingStack & bt)62     RevertingStack(const RevertingStack& bt):
63     mb(bt.mb), stack(bt.stack)
64     { }
65 
66     void push_back(const T& t);
67     void push_back(T&& t);
68 
back()69     const T& back()
70     { return stack->back(); }
71 
get() const72     const vec1<T>& get() const
73     { return *stack; }
74 
size() const75     int size() const
76     { return stack->size(); }
77 
78     // Why put 'dangerous' in this method name?
79     // Because only the size is backtracked, not the contents!
getMutable_dangerous()80     vec1<T>& getMutable_dangerous()
81     { return *stack; }
82 
83     // This can only be called before the first 'push'.
84     // we will notice that, on the later backtrack.
clearStack_dangerous()85     void clearStack_dangerous()
86     { stack->clear(); }
87 };
88 
// Backtrack action: shrink the container at 'ptr' back to 'val' elements.
// 'ptr' must point to a T (MemoryBacktracker::storeCurrentSize arranges this
// by pairing the pointer with resizeBacktrackStack<Vec>).
template<typename T>
void resizeBacktrackStack(void* ptr, int val)
{
    T& container = *static_cast<T*>(ptr);
    // A backtrack may only shrink: the recorded size cannot exceed the current one.
    D_ASSERT(container.size() >= val);
    container.resize(val);
}
96 
// Signature of a deferred undo action: an opaque object pointer plus one
// integer of saved state.
using backtrack_function = void (*)(void*, int);

// A deferred undo action; invoking it calls fun(ptr, data).
struct BacktrackObj
{
    backtrack_function fun;  // action to run
    void* ptr;               // object the action operates on
    int data;                // saved state handed to the action

    void operator()()
    {
        backtrack_function const action = fun;
        action(ptr, data);
    }
};
108 
109 
110 
111 
112 class MemoryBacktracker
113 {
114     vec1<vec1<std::pair<int*, int> > > reversions;
115     vec1<vec1<BacktrackObj> > function_reversions;
116 
117     // This just stores all the objects we allocated,
118     // so we can clean up at the end.
119     vec1<void*> raw_mem_store;
120     vec1<FreeObj> stack_mem_store;
121 
122     std::set<BacktrackableType*> objects_to_notify;
123 public:
registerBacktrackableObject(BacktrackableType * bt)124     void registerBacktrackableObject(BacktrackableType* bt)
125     { objects_to_notify.insert(bt); }
126 
unregisterBacktrackableObject(BacktrackableType * bt)127     void unregisterBacktrackableObject(BacktrackableType* bt)
128     { objects_to_notify.erase(bt); }
129 private:
130 
131 // forbid copying
132     MemoryBacktracker(const MemoryBacktracker&);
133 
134 public:
MemoryBacktracker()135     MemoryBacktracker() : reversions(1), function_reversions(1)
136     { }
137 
~MemoryBacktracker()138     ~MemoryBacktracker()
139     {
140         for(int i : range1(raw_mem_store.size()))
141             free(raw_mem_store[i]);
142         for(int i : range1(stack_mem_store.size()))
143             stack_mem_store[i]();
144     }
145 
storeCurrentValue(int * ptr)146     void storeCurrentValue(int* ptr)
147     {
148         reversions.back().push_back(std::make_pair(ptr, *ptr));
149     }
150 
151     template<typename Vec>
storeCurrentSize(Vec * ptr)152     void storeCurrentSize(Vec* ptr)
153     {
154         BacktrackObj obj;
155         obj.fun = resizeBacktrackStack<Vec>;
156         obj.ptr = ptr;
157         obj.data = ptr->size();
158         function_reversions.back().push_back(obj);
159     }
160 
pushWorld()161     void pushWorld()
162     {
163         debug_out(1, "MemoryManager", "pushWorld");
164         for(auto it : objects_to_notify)
165             it->event_pushWorld();
166 
167         reversions.resize(reversions.size() + 1);
168         function_reversions.resize(function_reversions.size() + 1);
169     }
170 
popWorld()171     void popWorld()
172     {
173         debug_out(1, "MemoryManager", "popWorld");
174 
175         // Need to go through last state, in reverse order.
176         vec1<std::pair<int*, int> >& backref = reversions.back();
177         for(int i = backref.size(); i >= 1; --i)
178         {
179             *(backref[i].first) = backref[i].second;
180         }
181         reversions.pop_back();
182 
183         vec1<BacktrackObj>& stackbackref = function_reversions.back();
184         for(int i = stackbackref.size(); i >= 1; --i)
185         {
186             (stackbackref[i])();
187         }
188         function_reversions.pop_back();
189 
190         for(auto it = objects_to_notify.rbegin();
191             it != objects_to_notify.rend(); ++it)
192             (*it)->event_popWorld();
193     }
194 
getDepth()195     int getDepth()
196     { return reversions.size(); }
197 
popWorldToDepth(int i)198     void popWorldToDepth(int i)
199     {
200         D_ASSERT(reversions.size() >= i);
201         while(reversions.size() > i)
202             popWorld();
203     }
204 
205     template<typename T>
makeReverting()206     Reverting<T> makeReverting()
207     {
208         void* ptr = calloc(1, sizeof(T));
209         raw_mem_store.push_back(ptr);
210         return Reverting<T>(this, (T*)ptr);
211     }
212 
213     template<typename T>
makeReverting(const T & t)214     Reverting<T> makeReverting(const T& t)
215     {
216         Reverting<T> r = makeReverting<T>();
217         r.set(t);
218         return r;
219     }
220 
221     template<typename T>
makeRevertingStack()222     RevertingStack<T> makeRevertingStack()
223     {
224         vec1<T>* ptr = new vec1<T>();
225         stack_mem_store.push_back(makeFreeObj(ptr));
226         return RevertingStack<T>(this, ptr);
227     }
228 };
229 
230 template<typename T>
set(const T & t)231 void Reverting<T>::set(const T& t)
232 {
233     mb->storeCurrentValue(val);
234     *val = t;
235 }
236 
237 template<typename T>
set(T && t)238 void Reverting<T>::set(T&& t)
239 {
240     mb->storeCurrentValue(val);
241     *val = std::move(t);
242 }
243 
244 template<typename T>
push_back(const T & t)245 void RevertingStack<T>::push_back(const T& t)
246 {
247     mb->storeCurrentSize(stack);
248     stack->push_back(t);
249 }
250 
251 template<typename T>
push_back(T && t)252 void RevertingStack<T>::push_back(T&& t)
253 {
254     mb->storeCurrentSize(stack);
255     stack->push_back(std::move(t));
256 }
257 
BacktrackableType(MemoryBacktracker * _mb)258 BacktrackableType::BacktrackableType(MemoryBacktracker* _mb)
259 : mb(_mb)
260 { if(mb != NULL) mb->registerBacktrackableObject(this); }
261 
~BacktrackableType()262 BacktrackableType::~BacktrackableType()
263 { if(mb != NULL) mb->unregisterBacktrackableObject(this); }
264 
265 #endif
266