1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
#include <cassert>
#include <cmath>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include <heyoka/detail/type_traits.hpp>
#include <heyoka/expression.hpp>
#include <heyoka/func.hpp>
#include <heyoka/gp.hpp>
#include <heyoka/math.hpp>
#include <heyoka/number.hpp>
#include <heyoka/variable.hpp>
27
28 namespace heyoka
29 {
30
31 namespace detail
32 {
33
34 namespace
35 {
36
// Select a uniformly-distributed random position in the range [start, end).
//
// start, end: iterators delimiting the range to pick from.
// g:          the random engine used for the draw.
//
// Returns an iterator to the selected element.
// Throws std::invalid_argument if the range is empty: there is nothing to
// select, and std::uniform_int_distribution requires its lower bound not to
// exceed its upper bound.
template <typename It, typename Rng>
It random_element(It start, It end, Rng &g)
{
    auto n_elems = std::distance(start, end);
    if (n_elems <= 0) {
        throw std::invalid_argument("Cannot select a random element from an empty range");
    }

    // Draw an offset in [0, n_elems - 1] and move the iterator there.
    std::uniform_int_distribution<decltype(n_elems)> dis(0, n_elems - 1);
    std::advance(start, dis(g));
    return start;
}
44
fetch_from_node_id_impl(expression & ex,std::size_t node_id,std::size_t & node_counter,expression * & ret)45 void fetch_from_node_id_impl(expression &ex, std::size_t node_id, std::size_t &node_counter, expression *&ret)
46 {
47 if (node_counter == node_id) {
48 ret = &ex;
49 } else {
50 ++node_counter;
51 std::visit(
52 [node_id, &node_counter, &ret](auto &node) {
53 using type = detail::uncvref_t<decltype(node)>;
54 if constexpr (std::is_same_v<type, func>) {
55 for (auto [b, e] = node.get_mutable_args_it(); b != e; ++b) {
56 fetch_from_node_id_impl(*b, node_id, node_counter, ret);
57 if (ret) {
58 return;
59 }
60 }
61 }
62 },
63 ex.value());
64 }
65 }
66
count_nodes_impl(const expression & e,std::size_t & node_counter)67 void count_nodes_impl(const expression &e, std::size_t &node_counter)
68 {
69 ++node_counter;
70 std::visit(
71 [&node_counter](auto &node) {
72 using type = detail::uncvref_t<decltype(node)>;
73 if constexpr (std::is_same_v<type, func>) {
74 for (auto &arg : node.args()) {
75 count_nodes_impl(arg, node_counter);
76 }
77 }
78 },
79 e.value());
80 }
81
82 } // namespace
83
84 } // namespace detail
85
expression_generator(const std::vector<std::string> & vars,splitmix64 & engine)86 expression_generator::expression_generator(const std::vector<std::string> &vars, splitmix64 &engine)
87 : m_vars(vars), m_b_funcs(), m_e(engine)
88 {
89 // These are the default blocks for a random expression.
90 m_u_funcs = {heyoka::sin, heyoka::cos};
91 m_b_funcs = {heyoka::add, heyoka::sub, heyoka::mul, heyoka::div};
92 m_range_dbl = 10;
93 // Default weights for the probability of selecting a bo, unary, binary, a variable, a number.
94 m_weights = {8., 2., 1., 4., 1.};
95 };
96
operator ()(unsigned min_depth,unsigned max_depth,unsigned depth) const97 expression expression_generator::operator()(unsigned min_depth, unsigned max_depth, unsigned depth) const
98 {
99 std::uniform_real_distribution<double> rng01(0.0, 1.0);
100 std::uniform_real_distribution<double> rngm11(-1.0, 1.0);
101
102 // First we decide what node type this will be.
103 node_type type;
104 if (depth < min_depth) {
105 // If the node depth is below the minimum desired, we force leaves (num or var) to be not selected
106 double n_u_fun = static_cast<double>(m_u_funcs.size());
107 double n_b_fun = static_cast<double>(m_b_funcs.size());
108 std::discrete_distribution<> dis({0, n_u_fun * m_weights[1], n_b_fun * m_weights[2]});
109 switch (dis(m_e)) {
110 case 0:
111 type = node_type::bo;
112 break;
113 case 1:
114 type = node_type::u_fun;
115 break;
116 case 2:
117 type = node_type::b_fun;
118 break;
119 }
120 } else if (depth >= max_depth) {
121 // If the node depth is above the maximum desired, we force leaves (num or var) to be selected
122 double n_var = static_cast<double>(m_vars.size());
123 std::discrete_distribution<> dis({n_var * m_weights[3], m_weights[4]});
124 switch (dis(m_e)) {
125 case 0:
126 type = node_type::var;
127 break;
128 case 1:
129 type = node_type::num;
130 break;
131 }
132 } else {
133 // else we can get anything
134 double n_u_fun = static_cast<double>(m_u_funcs.size());
135 double n_b_fun = static_cast<double>(m_b_funcs.size());
136 double n_var = static_cast<double>(m_vars.size());
137 std::discrete_distribution<> dis(
138 {0, n_u_fun * m_weights[1], n_b_fun * m_weights[2], n_var * m_weights[3], m_weights[4]});
139 switch (dis(m_e)) {
140 case 0:
141 type = node_type::bo;
142 break;
143 case 1:
144 type = node_type::u_fun;
145 break;
146 case 2:
147 type = node_type::b_fun;
148 break;
149 case 3:
150 type = node_type::var;
151 break;
152 case 4:
153 type = node_type::num;
154 break;
155 }
156 }
157 // Once we know the node type we create one at random out of the user defined possible choices
158 switch (type) {
159 case node_type::num: {
160 // We return a random number in -m_range_dbl, m_range_dbl
161 auto value = rngm11(m_e) * m_range_dbl;
162 return expression{number{value}};
163 break;
164 }
165 case node_type::var: {
166 // We return one of the variables in m_vars
167 auto symbol = *detail::random_element(m_vars.begin(), m_vars.end(), m_e);
168 return expression{variable{symbol}};
169 break;
170 }
171 case node_type::u_fun: {
172 // We return one of the unary functions in m_u_funcs with randomly constructed argument
173 auto u_f = *detail::random_element(m_u_funcs.begin(), m_u_funcs.end(), m_e);
174 return u_f(this->operator()(min_depth, max_depth, depth + 1));
175 break;
176 }
177 case node_type::b_fun: {
178 // We return one of the binary functions in m_b_funcs with randomly constructed arguments
179 auto b_f = *detail::random_element(m_b_funcs.begin(), m_b_funcs.end(), m_e);
180 return b_f(this->operator()(min_depth, max_depth, depth + 1),
181 this->operator()(min_depth, max_depth, depth + 1));
182 break;
183 }
184 default:
185 throw;
186 }
187 };
188
get_u_funcs() const189 const std::vector<expression (*)(expression)> &expression_generator::get_u_funcs() const
190 {
191 return m_u_funcs;
192 }
get_b_funcs() const193 const std::vector<expression (*)(expression, expression)> &expression_generator::get_b_funcs() const
194 {
195 return m_b_funcs;
196 }
get_vars() const197 const std::vector<std::string> &expression_generator::get_vars() const
198 {
199 return m_vars;
200 }
get_range_dbl() const201 const double &expression_generator::get_range_dbl() const
202 {
203 return m_range_dbl;
204 }
get_weights() const205 const std::vector<double> &expression_generator::get_weights() const
206 {
207 return m_weights;
208 }
209
set_u_funcs(const std::vector<expression (*)(expression)> & u_funcs)210 void expression_generator::set_u_funcs(const std::vector<expression (*)(expression)> &u_funcs)
211 {
212 m_u_funcs = u_funcs;
213 }
set_b_funcs(const std::vector<expression (*)(expression,expression)> & b_funcs)214 void expression_generator::set_b_funcs(const std::vector<expression (*)(expression, expression)> &b_funcs)
215 {
216 m_b_funcs = b_funcs;
217 }
set_vars(const std::vector<std::string> & vars)218 void expression_generator::set_vars(const std::vector<std::string> &vars)
219 {
220 m_vars = vars;
221 }
set_range_dbl(const double & rd)222 void expression_generator::set_range_dbl(const double &rd)
223 {
224 m_range_dbl = rd;
225 }
set_weights(const std::vector<double> & w)226 void expression_generator::set_weights(const std::vector<double> &w)
227 {
228 if (w.size() != 5) {
229 throw std::invalid_argument(
230 "The weight vector for the probablity distribution of the node type must have size "
231 "5 -> (binary operator, unary functions, binary functions, variable, numbers), while I detected a size of: "
232 + std::to_string(w.size()));
233 }
234 m_weights = w;
235 }
236
operator <<(std::ostream & os,const expression_generator & eg)237 std::ostream &operator<<(std::ostream &os, const expression_generator &eg)
238 {
239 os << "Expression Generator:";
240 auto vars = eg.get_vars();
241
242 os << "\nVariables: ";
243 for (const auto &var : vars) {
244 os << var << " ";
245 }
246 auto u_funcs = eg.get_u_funcs();
247 if (u_funcs.size()) {
248 os << "\nUnary Functions: ";
249 for (const auto &u_func : u_funcs) {
250 os << u_func("."_var) << " ";
251 }
252 }
253 auto b_funcs = eg.get_b_funcs();
254 if (b_funcs.size()) {
255 os << "\nBinary Functions: ";
256 for (const auto &b_func : b_funcs) {
257 os << b_func("."_var, "."_var) << " ";
258 }
259 }
260 os << "\nRandom double constants range: ";
261 os << "[-" << eg.get_range_dbl() << ", " << eg.get_range_dbl() << "]";
262 os << "\nWeights:";
263 os << "\n\tBinary operator: " << eg.get_weights()[0];
264 os << "\n\tUnary function: " << eg.get_weights()[1];
265 os << "\n\tBinary function: " << eg.get_weights()[2];
266 os << "\n\tVariable: " << eg.get_weights()[3];
267 os << "\n\tConstant: " << eg.get_weights()[4];
268
269 return os << "\n";
270 }
271
272 // Version randomly selecting nodes during traversal (PROBABLY WILL BE REMOVED)
mutate(expression & e,const expression_generator & generator,const double mut_p,splitmix64 & engine,const unsigned min_depth,const unsigned max_depth,unsigned depth)273 void mutate(expression &e, const expression_generator &generator, const double mut_p, splitmix64 &engine,
274 const unsigned min_depth, const unsigned max_depth, unsigned depth)
275 {
276 std::uniform_real_distribution<> rng01(0., 1.);
277 if (rng01(engine) < mut_p) {
278 e = generator(min_depth, max_depth, depth);
279 } else {
280 std::visit(
281 [&generator, &mut_p, &min_depth, &max_depth, &depth, &engine](auto &node) {
282 if constexpr (std::is_same_v<decltype(node), func &>) {
283 for (auto [b, e] = node.get_mutable_args_it(); b != e; ++b) {
284 mutate(*b, generator, mut_p, engine, min_depth, max_depth, depth + 1);
285 }
286 }
287 },
288 e.value());
289 }
290 }
291 // Version targeting a node
mutate(expression & e,std::size_t node_id,const expression_generator & generator,const unsigned min_depth,const unsigned max_depth)292 void mutate(expression &e, std::size_t node_id, const expression_generator &generator, const unsigned min_depth,
293 const unsigned max_depth)
294 {
295 auto e_sub_ptr = fetch_from_node_id(e, node_id);
296 if (!e_sub_ptr) {
297 throw std::invalid_argument("The node id requested: " + std::to_string(node_id)
298 + " was not found in the expression e1: ");
299 }
300 *e_sub_ptr = generator(min_depth, max_depth);
301 }
302
count_nodes(const expression & e)303 std::size_t count_nodes(const expression &e)
304 {
305 std::size_t node_counter = 0u;
306 detail::count_nodes_impl(e, node_counter);
307 return node_counter;
308 }
309
fetch_from_node_id(expression & ex,std::size_t node_id)310 expression *fetch_from_node_id(expression &ex, std::size_t node_id)
311 {
312 std::size_t cur_id = 0;
313 expression *ret = nullptr;
314
315 detail::fetch_from_node_id_impl(ex, node_id, cur_id, ret);
316
317 return ret;
318 }
319
320 // Crossover
crossover(expression & e1,expression & e2,splitmix64 & engine)321 void crossover(expression &e1, expression &e2, splitmix64 &engine)
322 {
323 std::uniform_int_distribution<std::size_t> t1(0, count_nodes(e1) - 1u);
324 std::uniform_int_distribution<std::size_t> t2(0, count_nodes(e2) - 1u);
325 auto node_id1 = t1(engine);
326 auto node_id2 = t2(engine);
327 auto e2_sub_ptr = fetch_from_node_id(e1, node_id1);
328 auto e1_sub_ptr = fetch_from_node_id(e2, node_id2);
329 assert(e2_sub_ptr != nullptr);
330 assert(e1_sub_ptr != nullptr);
331 swap(*e2_sub_ptr, *e1_sub_ptr);
332 }
333
334 // Crossover targeting specific node_ids
crossover(expression & e1,expression & e2,std::size_t node_id1,std::size_t node_id2)335 void crossover(expression &e1, expression &e2, std::size_t node_id1, std::size_t node_id2)
336 {
337 auto e2_sub_ptr = fetch_from_node_id(e1, node_id1);
338 auto e1_sub_ptr = fetch_from_node_id(e2, node_id2);
339 if (!e1_sub_ptr) {
340 throw std::invalid_argument("The node id requested: " + std::to_string(node_id1)
341 + " was not found in the expression e1: ");
342 } else if (!e2_sub_ptr) {
343 throw std::invalid_argument("The node id requested: " + std::to_string(node_id2)
344 + " was not found in the expression e2: ");
345 }
346 swap(*e2_sub_ptr, *e1_sub_ptr);
347 }
348
349 } // namespace heyoka
350