1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
#include <cassert>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
19 
20 #include <heyoka/detail/type_traits.hpp>
21 #include <heyoka/expression.hpp>
22 #include <heyoka/func.hpp>
23 #include <heyoka/gp.hpp>
24 #include <heyoka/math.hpp>
25 #include <heyoka/number.hpp>
26 #include <heyoka/variable.hpp>
27 
28 namespace heyoka
29 {
30 
31 namespace detail
32 {
33 
34 namespace
35 {
36 
// Pick a uniformly distributed random position in the half-open range [start, end).
//
// start, end: iterators delimiting the range to pick from.
// g:          random engine used to draw the index.
//
// Returns an iterator to the selected element. If the range is empty there is
// nothing to pick: end is returned and g is not used (building a
// std::uniform_int_distribution with bounds (0, -1) would violate the
// distribution's a <= b precondition).
template <typename It, typename Rng>
It random_element(It start, It end, Rng &g)
{
    using diff_t = decltype(std::distance(start, end));

    const diff_t size = std::distance(start, end);
    if (size <= 0) {
        return end;
    }

    // Draw an index in [0, size - 1] and move start forward by that many steps.
    std::uniform_int_distribution<diff_t> dis(0, size - 1);
    std::advance(start, dis(g));
    return start;
}
44 
fetch_from_node_id_impl(expression & ex,std::size_t node_id,std::size_t & node_counter,expression * & ret)45 void fetch_from_node_id_impl(expression &ex, std::size_t node_id, std::size_t &node_counter, expression *&ret)
46 {
47     if (node_counter == node_id) {
48         ret = &ex;
49     } else {
50         ++node_counter;
51         std::visit(
52             [node_id, &node_counter, &ret](auto &node) {
53                 using type = detail::uncvref_t<decltype(node)>;
54                 if constexpr (std::is_same_v<type, func>) {
55                     for (auto [b, e] = node.get_mutable_args_it(); b != e; ++b) {
56                         fetch_from_node_id_impl(*b, node_id, node_counter, ret);
57                         if (ret) {
58                             return;
59                         }
60                     }
61                 }
62             },
63             ex.value());
64     }
65 }
66 
count_nodes_impl(const expression & e,std::size_t & node_counter)67 void count_nodes_impl(const expression &e, std::size_t &node_counter)
68 {
69     ++node_counter;
70     std::visit(
71         [&node_counter](auto &node) {
72             using type = detail::uncvref_t<decltype(node)>;
73             if constexpr (std::is_same_v<type, func>) {
74                 for (auto &arg : node.args()) {
75                     count_nodes_impl(arg, node_counter);
76                 }
77             }
78         },
79         e.value());
80 }
81 
82 } // namespace
83 
84 } // namespace detail
85 
expression_generator(const std::vector<std::string> & vars,splitmix64 & engine)86 expression_generator::expression_generator(const std::vector<std::string> &vars, splitmix64 &engine)
87     : m_vars(vars), m_b_funcs(), m_e(engine)
88 {
89     // These are the default blocks for a random expression.
90     m_u_funcs = {heyoka::sin, heyoka::cos};
91     m_b_funcs = {heyoka::add, heyoka::sub, heyoka::mul, heyoka::div};
92     m_range_dbl = 10;
93     // Default weights for the probability of selecting a bo, unary, binary, a variable, a number.
94     m_weights = {8., 2., 1., 4., 1.};
95 };
96 
operator ()(unsigned min_depth,unsigned max_depth,unsigned depth) const97 expression expression_generator::operator()(unsigned min_depth, unsigned max_depth, unsigned depth) const
98 {
99     std::uniform_real_distribution<double> rng01(0.0, 1.0);
100     std::uniform_real_distribution<double> rngm11(-1.0, 1.0);
101 
102     // First we decide what node type this will be.
103     node_type type;
104     if (depth < min_depth) {
105         // If the node depth is below the minimum desired, we force leaves (num or var) to be not selected
106         double n_u_fun = static_cast<double>(m_u_funcs.size());
107         double n_b_fun = static_cast<double>(m_b_funcs.size());
108         std::discrete_distribution<> dis({0, n_u_fun * m_weights[1], n_b_fun * m_weights[2]});
109         switch (dis(m_e)) {
110             case 0:
111                 type = node_type::bo;
112                 break;
113             case 1:
114                 type = node_type::u_fun;
115                 break;
116             case 2:
117                 type = node_type::b_fun;
118                 break;
119         }
120     } else if (depth >= max_depth) {
121         // If the node depth is above the maximum desired, we force leaves (num or var) to be selected
122         double n_var = static_cast<double>(m_vars.size());
123         std::discrete_distribution<> dis({n_var * m_weights[3], m_weights[4]});
124         switch (dis(m_e)) {
125             case 0:
126                 type = node_type::var;
127                 break;
128             case 1:
129                 type = node_type::num;
130                 break;
131         }
132     } else {
133         // else we can get anything
134         double n_u_fun = static_cast<double>(m_u_funcs.size());
135         double n_b_fun = static_cast<double>(m_b_funcs.size());
136         double n_var = static_cast<double>(m_vars.size());
137         std::discrete_distribution<> dis(
138             {0, n_u_fun * m_weights[1], n_b_fun * m_weights[2], n_var * m_weights[3], m_weights[4]});
139         switch (dis(m_e)) {
140             case 0:
141                 type = node_type::bo;
142                 break;
143             case 1:
144                 type = node_type::u_fun;
145                 break;
146             case 2:
147                 type = node_type::b_fun;
148                 break;
149             case 3:
150                 type = node_type::var;
151                 break;
152             case 4:
153                 type = node_type::num;
154                 break;
155         }
156     }
157     // Once we know the node type we create one at random out of the user defined possible choices
158     switch (type) {
159         case node_type::num: {
160             // We return a random number in -m_range_dbl, m_range_dbl
161             auto value = rngm11(m_e) * m_range_dbl;
162             return expression{number{value}};
163             break;
164         }
165         case node_type::var: {
166             // We return one of the variables in m_vars
167             auto symbol = *detail::random_element(m_vars.begin(), m_vars.end(), m_e);
168             return expression{variable{symbol}};
169             break;
170         }
171         case node_type::u_fun: {
172             // We return one of the unary functions in m_u_funcs with randomly constructed argument
173             auto u_f = *detail::random_element(m_u_funcs.begin(), m_u_funcs.end(), m_e);
174             return u_f(this->operator()(min_depth, max_depth, depth + 1));
175             break;
176         }
177         case node_type::b_fun: {
178             // We return one of the binary functions in m_b_funcs with randomly constructed arguments
179             auto b_f = *detail::random_element(m_b_funcs.begin(), m_b_funcs.end(), m_e);
180             return b_f(this->operator()(min_depth, max_depth, depth + 1),
181                        this->operator()(min_depth, max_depth, depth + 1));
182             break;
183         }
184         default:
185             throw;
186     }
187 };
188 
get_u_funcs() const189 const std::vector<expression (*)(expression)> &expression_generator::get_u_funcs() const
190 {
191     return m_u_funcs;
192 }
get_b_funcs() const193 const std::vector<expression (*)(expression, expression)> &expression_generator::get_b_funcs() const
194 {
195     return m_b_funcs;
196 }
get_vars() const197 const std::vector<std::string> &expression_generator::get_vars() const
198 {
199     return m_vars;
200 }
get_range_dbl() const201 const double &expression_generator::get_range_dbl() const
202 {
203     return m_range_dbl;
204 }
get_weights() const205 const std::vector<double> &expression_generator::get_weights() const
206 {
207     return m_weights;
208 }
209 
set_u_funcs(const std::vector<expression (*)(expression)> & u_funcs)210 void expression_generator::set_u_funcs(const std::vector<expression (*)(expression)> &u_funcs)
211 {
212     m_u_funcs = u_funcs;
213 }
set_b_funcs(const std::vector<expression (*)(expression,expression)> & b_funcs)214 void expression_generator::set_b_funcs(const std::vector<expression (*)(expression, expression)> &b_funcs)
215 {
216     m_b_funcs = b_funcs;
217 }
set_vars(const std::vector<std::string> & vars)218 void expression_generator::set_vars(const std::vector<std::string> &vars)
219 {
220     m_vars = vars;
221 }
set_range_dbl(const double & rd)222 void expression_generator::set_range_dbl(const double &rd)
223 {
224     m_range_dbl = rd;
225 }
set_weights(const std::vector<double> & w)226 void expression_generator::set_weights(const std::vector<double> &w)
227 {
228     if (w.size() != 5) {
229         throw std::invalid_argument(
230             "The weight vector for the probablity distribution of the node type must have size "
231             "5 -> (binary operator, unary functions, binary functions, variable, numbers), while I detected a size of: "
232             + std::to_string(w.size()));
233     }
234     m_weights = w;
235 }
236 
operator <<(std::ostream & os,const expression_generator & eg)237 std::ostream &operator<<(std::ostream &os, const expression_generator &eg)
238 {
239     os << "Expression Generator:";
240     auto vars = eg.get_vars();
241 
242     os << "\nVariables: ";
243     for (const auto &var : vars) {
244         os << var << " ";
245     }
246     auto u_funcs = eg.get_u_funcs();
247     if (u_funcs.size()) {
248         os << "\nUnary Functions: ";
249         for (const auto &u_func : u_funcs) {
250             os << u_func("."_var) << " ";
251         }
252     }
253     auto b_funcs = eg.get_b_funcs();
254     if (b_funcs.size()) {
255         os << "\nBinary Functions: ";
256         for (const auto &b_func : b_funcs) {
257             os << b_func("."_var, "."_var) << " ";
258         }
259     }
260     os << "\nRandom double constants range: ";
261     os << "[-" << eg.get_range_dbl() << ", " << eg.get_range_dbl() << "]";
262     os << "\nWeights:";
263     os << "\n\tBinary operator: " << eg.get_weights()[0];
264     os << "\n\tUnary function: " << eg.get_weights()[1];
265     os << "\n\tBinary function: " << eg.get_weights()[2];
266     os << "\n\tVariable: " << eg.get_weights()[3];
267     os << "\n\tConstant: " << eg.get_weights()[4];
268 
269     return os << "\n";
270 }
271 
272 // Version randomly selecting nodes during traversal (PROBABLY WILL BE REMOVED)
mutate(expression & e,const expression_generator & generator,const double mut_p,splitmix64 & engine,const unsigned min_depth,const unsigned max_depth,unsigned depth)273 void mutate(expression &e, const expression_generator &generator, const double mut_p, splitmix64 &engine,
274             const unsigned min_depth, const unsigned max_depth, unsigned depth)
275 {
276     std::uniform_real_distribution<> rng01(0., 1.);
277     if (rng01(engine) < mut_p) {
278         e = generator(min_depth, max_depth, depth);
279     } else {
280         std::visit(
281             [&generator, &mut_p, &min_depth, &max_depth, &depth, &engine](auto &node) {
282                 if constexpr (std::is_same_v<decltype(node), func &>) {
283                     for (auto [b, e] = node.get_mutable_args_it(); b != e; ++b) {
284                         mutate(*b, generator, mut_p, engine, min_depth, max_depth, depth + 1);
285                     }
286                 }
287             },
288             e.value());
289     }
290 }
291 // Version targeting a node
mutate(expression & e,std::size_t node_id,const expression_generator & generator,const unsigned min_depth,const unsigned max_depth)292 void mutate(expression &e, std::size_t node_id, const expression_generator &generator, const unsigned min_depth,
293             const unsigned max_depth)
294 {
295     auto e_sub_ptr = fetch_from_node_id(e, node_id);
296     if (!e_sub_ptr) {
297         throw std::invalid_argument("The node id requested: " + std::to_string(node_id)
298                                     + " was not found in the expression e1: ");
299     }
300     *e_sub_ptr = generator(min_depth, max_depth);
301 }
302 
count_nodes(const expression & e)303 std::size_t count_nodes(const expression &e)
304 {
305     std::size_t node_counter = 0u;
306     detail::count_nodes_impl(e, node_counter);
307     return node_counter;
308 }
309 
fetch_from_node_id(expression & ex,std::size_t node_id)310 expression *fetch_from_node_id(expression &ex, std::size_t node_id)
311 {
312     std::size_t cur_id = 0;
313     expression *ret = nullptr;
314 
315     detail::fetch_from_node_id_impl(ex, node_id, cur_id, ret);
316 
317     return ret;
318 }
319 
320 // Crossover
crossover(expression & e1,expression & e2,splitmix64 & engine)321 void crossover(expression &e1, expression &e2, splitmix64 &engine)
322 {
323     std::uniform_int_distribution<std::size_t> t1(0, count_nodes(e1) - 1u);
324     std::uniform_int_distribution<std::size_t> t2(0, count_nodes(e2) - 1u);
325     auto node_id1 = t1(engine);
326     auto node_id2 = t2(engine);
327     auto e2_sub_ptr = fetch_from_node_id(e1, node_id1);
328     auto e1_sub_ptr = fetch_from_node_id(e2, node_id2);
329     assert(e2_sub_ptr != nullptr);
330     assert(e1_sub_ptr != nullptr);
331     swap(*e2_sub_ptr, *e1_sub_ptr);
332 }
333 
334 // Crossover targeting specific node_ids
crossover(expression & e1,expression & e2,std::size_t node_id1,std::size_t node_id2)335 void crossover(expression &e1, expression &e2, std::size_t node_id1, std::size_t node_id2)
336 {
337     auto e2_sub_ptr = fetch_from_node_id(e1, node_id1);
338     auto e1_sub_ptr = fetch_from_node_id(e2, node_id2);
339     if (!e1_sub_ptr) {
340         throw std::invalid_argument("The node id requested: " + std::to_string(node_id1)
341                                     + " was not found in the expression e1: ");
342     } else if (!e2_sub_ptr) {
343         throw std::invalid_argument("The node id requested: " + std::to_string(node_id2)
344                                     + " was not found in the expression e2: ");
345     }
346     swap(*e2_sub_ptr, *e1_sub_ptr);
347 }
348 
349 } // namespace heyoka
350