#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <queue>
#include <vector>

#include <boost/functional.hpp>

#include "OS.h"
10 
11 namespace search {
12 
/**
 * Describes a successor function object of type T.
 *
 * action_type is taken from the nested T::action_type; the search algorithms
 * in this file collect values of that type in a std::vector.
 */
template<class T>
struct successor_traits {
  using action_type = typename T::action_type;
};
17 
18 template<typename T>
19 struct action_traits {
20   typedef typename boost::unary_traits<T>::result_type delta_type;
21 };
22 
23 template<typename T>
24 struct evaluator_traits {
25   typedef typename boost::unary_traits<T>::result_type score_type;
26 };
27 
28 template<typename T>
29 struct selector_traits {
30   typedef typename boost::unary_traits<T>::result_type fragment_type;
31 };
32 
33 template<typename T>
34 struct search_traits {
35   typedef typename boost::binary_traits<T>::result_type assignment_type;
36 };
37 
38 template<typename State,
39          typename Visitor,
40          typename Successor>
depth_first_search(State & init,Visitor & visit,Successor & fn)41 void depth_first_search(State &init, Visitor &visit, Successor &fn) {
42   typedef typename successor_traits<Successor>::action_type action;
43   typedef typename action_traits<action>::delta_type delta;
44 
45   visit(init);
46 
47   std::vector<action> actions;
48   fn(init, std::back_inserter(actions));
49 
50   for (typename std::vector<action>::const_iterator it = actions.begin();
51        it != actions.end(); it++) {
52     const action &action = *it;
53     delta change(action(init));
54 
55     change.apply(init);
56     depth_first_search(init, visit, fn);
57     change.reverse(init);
58   }
59 }
60 
61 template<typename State,
62          typename Visitor,
63          typename Successor>
depth_limited_search(State & init,Visitor & visit,Successor & fn,size_t max_depth)64 void depth_limited_search(State &init, Visitor &visit, Successor &fn,
65                           size_t max_depth) {
66   typedef typename successor_traits<Successor>::action_type action;
67   typedef typename action_traits<action>::delta_type delta;
68 
69   visit(init);
70   if (max_depth == 0) return;
71 
72   std::vector<action> actions;
73   fn(init, std::back_inserter(actions));
74 
75   for (typename std::vector<action>::const_iterator it = actions.begin();
76        it != actions.end(); it++) {
77     const action &action = *it;
78     delta change(action(init));
79 
80     change.apply(init);
81     depth_limited_search(init, visit, fn, max_depth - 1);
82     change.reverse(init);
83   }
84 }
85 
86 template<typename State,
87          typename Visitor,
88          typename Successor,
89          typename Queue>
tree_search(State & init,Visitor & visit,Successor & fn,Queue q)90 void tree_search(State &init, Visitor &visit, Successor &fn, Queue q) {
91   typedef typename successor_traits<Successor>::action_type action;
92 
93   q.push(init);
94 
95   while (!q.empty()) {
96     State s = q.top();
97     q.pop();
98 
99     visit(s);
100 
101     std::vector<action> actions;
102     fn(s, std::back_inserter(actions));
103 
104     for (typename std::vector<action>::const_iterator it = actions.begin();
105        it != actions.end(); it++) {
106       const action &action = *it;
107 
108       State to_add(s);
109       action(s).apply(to_add);
110 
111       q.push(to_add);
112     }
113   }
114 }
115 
/**
 * Runs tree_search from init, handing it an empty std::queue<State> as the
 * container of pending states. visit and fn are forwarded unchanged.
 */
template<typename State,
         typename Visitor,
         typename Successor>
void breadth_first_search(State &init, Visitor &visit, Successor &fn) {
  typedef std::queue<State> fifo_type;

  tree_search(init, visit, fn, fifo_type());
}
122 
/**
 * Runs tree_search from init with a std::priority_queue ordered by c as the
 * container of pending states, so the state removed at each step is the one
 * std::priority_queue::top() designates (the greatest according to c).
 *
 * Compare defaults to std::less<State>. visit and fn are forwarded unchanged.
 */
template<typename State,
         typename Visitor,
         typename Successor,
         typename Compare = std::less<State> >
void best_first_search(State &init, Visitor &visit, Successor &fn,
                       Compare c = Compare()) {
  /* The comparator is the THIRD template argument of std::priority_queue; the
   * second one is the underlying container. */
  typedef std::priority_queue<State, std::vector<State>, Compare> frontier_type;

  tree_search(init, visit, fn, frontier_type(c));
}
131 
132 /**
133  * At each step, the successor returned through fn which maximizes eval, until
134  * reaching a node which has no successors.
135  */
136 template<typename State,
137          typename Successor,
138          typename Evaluator>
greedy_search(State & state,Successor & fn,Evaluator & eval)139 void greedy_search(State &state, Successor &fn, Evaluator &eval) {
140   typedef typename successor_traits<Successor>::action_type action;
141   typedef typename action_traits<action>::delta_type delta;
142   typedef typename evaluator_traits<Evaluator>::score_type score;
143 
144   while (true) {
145     std::vector<action> actions;
146     fn(state, std::back_inserter(actions));
147 
148     if (actions.empty()) return;
149 
150     if (actions.size() == 1) {
151       actions[0](state).apply(state);
152       continue;
153     }
154 
155     std::vector<delta> changes(actions.size());
156 
157     typename std::vector<action>::const_iterator actions_it;
158     typename std::vector<delta>::iterator changes_it;
159     for (actions_it = actions.begin(), changes_it = changes.begin();
160          actions_it != actions.end(); actions_it++, changes_it++) {
161       const action &a = *actions_it;
162       *changes_it = a(state);
163     }
164 
165     std::vector<score> scores(actions.size());
166 
167     typename std::vector<score>::iterator scores_it;
168     for (changes_it = changes.begin(), scores_it = scores.begin();
169          changes_it != changes.end(); changes_it++, scores_it++) {
170       const delta &delta = *changes_it;
171 
172       delta.apply(state);
173       *scores_it = eval(state);
174       delta.reverse(state);
175     }
176 
177     typename std::vector<score>::iterator it =
178       std::max_element(scores.begin(), scores.end());
179 
180     size_t index = std::distance(scores.begin(), it);
181     changes[index].apply(state);
182   }
183 }
184 
185 template<typename State,
186          typename Selector,
187          typename Search>
large_neighborhood_search(State & state,Selector & selector,Search & search,std::size_t time_limit=100)188 void large_neighborhood_search(State &state, Selector &selector,
189                                Search &search, std::size_t time_limit = 100) {
190   typedef typename selector_traits<Selector>::fragment_type fragment;
191   typedef typename search_traits<Search>::assignment_type assignment;
192 
193   /* TODO: Use a monotonic clock instead */
194   double start = TimeOfDay();
195   while (TimeOfDay() - start < time_limit) {
196     fragment fragment(selector(state));
197     assignment assignment(search(state, fragment));
198     assignment(state, fragment);
199   }
200 }
201 
/**
 * Node of the statistics tree built by Monte Carlo Tree Search.
 *
 * Score must be constructible from 0 and from a std::size_t, and support the
 * arithmetic used below (+=, *, /, <) as well as std::sqrt and std::log.
 */
template<typename Score>
struct mcts_node {
  /** Creates an unvisited node without children. */
  mcts_node():
    x_1(0), x_2(0), visit_count(0), children() {}

  Score x_1;               /* sum of the scores passed to update() */
  Score x_2;               /* sum of the squares of those scores */
  std::size_t visit_count; /* number of calls to update() */

  std::vector<mcts_node> children;

  /** Records the outcome of one simulation that went through this node. */
  void update(Score score) {
    x_1 += score;
    x_2 += score*score;
    visit_count++;
  }

  /**
   * UCB1 value of this node: its mean score plus the exploration term
   * sqrt(2 ln(n) / visit_count). Unvisited nodes get the largest value Score
   * can represent, so they are always preferred.
   *
   * @param n value whose logarithm scales the exploration term
   */
  Score ucb1(std::size_t n) const {
    if (visit_count == 0)
      return std::numeric_limits<Score>::max();
    else
      return (x_1/visit_count) + std::sqrt(2*std::log(Score(n))/visit_count);
  }

  /** Orders nodes by increasing ucb1(n). */
  class compare_by_ucb1 {
    std::size_t _n;
  public:
    compare_by_ucb1(std::size_t n): _n(n) {}

    bool operator()(const mcts_node<Score> &a, const mcts_node<Score> &b) const {
      return a.ucb1(_n) < b.ucb1(_n);
    }
  };

  /**
   * Orders nodes by increasing mean score (x_1 / visit_count). Unvisited
   * nodes have no mean; they are all equivalent and sort below every visited
   * node, so they are never preferred over a move that was actually tried.
   */
  struct compare_by_mean {
    bool operator()(const mcts_node<Score> &a, const mcts_node<Score> &b) const {
      if (b.visit_count == 0) return false;
      if (a.visit_count == 0) return true;
      return a.x_1/a.visit_count < b.x_1/b.visit_count;
    }
  };

  /**
   * Index of the first child with the highest ucb1(n). Returns 0 when there
   * are no children, so callers must check children is not empty.
   */
  std::size_t best_child(std::size_t n) const {
    /* children is const here, hence const_iterator. */
    typename std::vector<mcts_node>::const_iterator it =
      std::max_element(children.begin(), children.end(),
                       compare_by_ucb1(n));
    return std::distance(children.begin(), it);
  }

  /**
   * Index of the first child with the highest mean score. Returns 0 when
   * there are no children. n is unused.
   */
  std::size_t best_move(std::size_t n) const {
    (void)n;
    typename std::vector<mcts_node>::const_iterator it =
      std::max_element(children.begin(), children.end(),
                       compare_by_mean());
    return std::distance(children.begin(), it);
  }
};
255 
256 /**
257  * Runs a single Monte Carlo Simulation as part of a Monte Carlo Tree Search.
258  *
259  * This operation is done as follows:
260  *   1. Explore the tree of statistics (node being a pointer to its root), playing
261  *      optimially according to those statistics, until reaching
262  *      one of its leaves.
263  *   2. Perform a completely random simulation, starting from the leaf found at
264  *      the previous step.
265  *   3. Update the statistics tree.
266  */
267 template<typename State,
268          typename Successor,
269          typename Evaluator,
270          typename Iterator>
mcts_simulation(mcts_node<typename evaluator_traits<Evaluator>::score_type> * node,State & state,Successor & fn,Evaluator & eval,size_t n,Iterator action_begin,Iterator action_end)271 void mcts_simulation(mcts_node<typename evaluator_traits<Evaluator>::score_type> *node,
272                      State &state, Successor &fn, Evaluator &eval, size_t n,
273                      Iterator action_begin, Iterator action_end) {
274   typedef typename successor_traits<Successor>::action_type action;
275   typedef typename action_traits<action>::delta_type delta;
276   typedef typename evaluator_traits<Evaluator>::score_type score;
277 
278   std::vector<action> actions(action_begin, action_end);
279 
280   std::vector<mcts_node<score>*> ancestors;
281   std::vector<delta> changes;
282 
283   ancestors.push_back(node);
284 
285   bool selection = true;
286 
287   while (!actions.empty()) {
288     std::size_t i;
289 
290     if (selection && node->children.size() == 0) {
291       node->children.resize(actions.size());
292 
293       i = rand() % actions.size();
294 
295       node = &node->children[i];
296       ancestors.push_back(node);
297 
298       selection = false;
299     }
300     else if (selection) {
301       i = node->best_child(n);
302       node = &node->children[i];
303       ancestors.push_back(node);
304     }
305     else {
306       i = rand() % actions.size();
307     }
308 
309     delta change = actions[i](state);
310     changes.push_back(change);
311 
312     change.apply(state);
313 
314     actions.clear();
315     fn(state, std::back_inserter(actions));
316   }
317 
318   score result(eval(state));
319 
320   for (typename std::vector<mcts_node<score>*>::const_iterator it = ancestors.begin();
321        it != ancestors.end(); it++) {
322     (*it)->update(result);
323   }
324 
325   for (typename std::vector<delta>::const_iterator it = changes.rbegin();
326        it != changes.rend(); it++) {
327     (*it).reverse(state);
328   }
329 }
330 
331 template<typename State,
332          typename Successor,
333          typename Evaluator>
monte_carlo_tree_search(State & state,Successor & fn,Evaluator & eval)334 void monte_carlo_tree_search(State &state, Successor &fn, Evaluator &eval) {
335   typedef typename successor_traits<Successor>::action_type action;
336   typedef typename evaluator_traits<Evaluator>::score_type score;
337 
338   std::vector<action> actions;
339 
340   while (true) {
341     actions.clear();
342     fn(state, std::back_inserter(actions));
343 
344     std::cout << "\rscore: " << eval(state);
345     std::cout.flush();
346 
347     if (actions.empty())
348       break;
349     else if (actions.size() == 1) {
350       actions[0](state).apply(state);
351     }
352     else {
353       mcts_node<score> node;
354 
355       size_t i;
356 
357       /* TODO: Use a monotonic clock instead */
358       double start = TimeOfDay();
359       for (i = 0; (TimeOfDay() - start) < 1.0; i++) {
360         mcts_simulation(&node, state, fn, eval, i,
361                         actions.begin(), actions.end());
362       }
363 
364       actions[node.best_move(i)](state).apply(state);
365     }
366   }
367 }
368 
369 }
370