1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 #include <heyoka/config.hpp>
10
11 #include <cassert>
12 #include <cstddef>
13 #include <functional>
14 #include <ostream>
15 #include <stdexcept>
16 #include <string>
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <utility>
20 #include <vector>
21
22 #include <fmt/format.h>
23
24 #if defined(HEYOKA_HAVE_REAL128)
25
26 #include <mp++/real128.hpp>
27
28 #endif
29
30 #include <heyoka/expression.hpp>
31 #include <heyoka/number.hpp>
32 #include <heyoka/variable.hpp>
33
34 namespace heyoka
35 {
36
variable()37 variable::variable() : variable("") {}
38
variable(std::string s)39 variable::variable(std::string s) : m_name(std::move(s)) {}
40
41 variable::variable(const variable &) = default;
42
43 variable::variable(variable &&) noexcept = default;
44
45 variable::~variable() = default;
46
47 variable &variable::operator=(const variable &) = default;
48
49 variable &variable::operator=(variable &&) noexcept = default;
50
name()51 std::string &variable::name()
52 {
53 return m_name;
54 }
55
name() const56 const std::string &variable::name() const
57 {
58 return m_name;
59 }
60
swap(variable & v0,variable & v1)61 void swap(variable &v0, variable &v1) noexcept
62 {
63 std::swap(v0.name(), v1.name());
64 }
65
hash(const variable & v)66 std::size_t hash(const variable &v)
67 {
68 return std::hash<std::string>{}(v.name());
69 }
70
operator <<(std::ostream & os,const variable & var)71 std::ostream &operator<<(std::ostream &os, const variable &var)
72 {
73 return os << var.name();
74 }
75
operator ==(const variable & v1,const variable & v2)76 bool operator==(const variable &v1, const variable &v2)
77 {
78 return v1.name() == v2.name();
79 }
80
operator !=(const variable & v1,const variable & v2)81 bool operator!=(const variable &v1, const variable &v2)
82 {
83 return !(v1 == v2);
84 }
85
eval_dbl(const variable & var,const std::unordered_map<std::string,double> & map,const std::vector<double> &)86 double eval_dbl(const variable &var, const std::unordered_map<std::string, double> &map, const std::vector<double> &)
87 {
88 using namespace fmt::literals;
89 if (auto it = map.find(var.name()); it != map.end()) {
90 return it->second;
91 } else {
92 throw std::invalid_argument(
93 "Cannot evaluate the variable '{}' because it is missing from the evaluation map"_format(var.name()));
94 }
95 }
96
eval_ldbl(const variable & var,const std::unordered_map<std::string,long double> & map,const std::vector<long double> &)97 long double eval_ldbl(const variable &var, const std::unordered_map<std::string, long double> &map,
98 const std::vector<long double> &)
99 {
100 using namespace fmt::literals;
101 if (auto it = map.find(var.name()); it != map.end()) {
102 return it->second;
103 } else {
104 throw std::invalid_argument(
105 "Cannot evaluate the variable '{}' because it is missing from the evaluation map"_format(var.name()));
106 }
107 }
108
109 #if defined(HEYOKA_HAVE_REAL128)
110
eval_f128(const variable & var,const std::unordered_map<std::string,mppp::real128> & map,const std::vector<mppp::real128> &)111 mppp::real128 eval_f128(const variable &var, const std::unordered_map<std::string, mppp::real128> &map,
112 const std::vector<mppp::real128> &)
113 {
114 using namespace fmt::literals;
115 if (auto it = map.find(var.name()); it != map.end()) {
116 return it->second;
117 } else {
118 throw std::invalid_argument(
119 "Cannot evaluate the variable '{}' because it is missing from the evaluation map"_format(var.name()));
120 }
121 }
122
123 #endif
124
eval_batch_dbl(std::vector<double> & out_values,const variable & var,const std::unordered_map<std::string,std::vector<double>> & map,const std::vector<double> &)125 void eval_batch_dbl(std::vector<double> &out_values, const variable &var,
126 const std::unordered_map<std::string, std::vector<double>> &map, const std::vector<double> &)
127 {
128 if (auto it = map.find(var.name()); it != map.end()) {
129 out_values = it->second;
130 } else {
131 throw std::invalid_argument("Cannot evaluate the variable '" + var.name()
132 + "' because it is missing from the evaluation map");
133 }
134 }
135
// Register this variable as a node in the connection graph: it gets an
// empty list of connections, and the running node counter advances by one.
void update_connections(std::vector<std::vector<std::size_t>> &node_connections, const variable &,
                        std::size_t &node_counter)
{
    node_connections.emplace_back();
    ++node_counter;
}
142
update_node_values_dbl(std::vector<double> & node_values,const variable & var,const std::unordered_map<std::string,double> & map,const std::vector<std::vector<std::size_t>> &,std::size_t & node_counter)143 void update_node_values_dbl(std::vector<double> &node_values, const variable &var,
144 const std::unordered_map<std::string, double> &map,
145 const std::vector<std::vector<std::size_t>> &, std::size_t &node_counter)
146 {
147 if (auto it = map.find(var.name()); it != map.end()) {
148 node_values[node_counter] = it->second;
149 } else {
150 throw std::invalid_argument("Cannot update the node output for the variable '" + var.name()
151 + "' because it is missing from the evaluation map");
152 }
153 node_counter++;
154 }
155
update_grad_dbl(std::unordered_map<std::string,double> & grad,const variable & var,const std::unordered_map<std::string,double> &,const std::vector<double> &,const std::vector<std::vector<std::size_t>> &,std::size_t & node_counter,double acc)156 void update_grad_dbl(std::unordered_map<std::string, double> &grad, const variable &var,
157 const std::unordered_map<std::string, double> &, const std::vector<double> &,
158 const std::vector<std::vector<std::size_t>> &, std::size_t &node_counter, double acc)
159 {
160 grad[var.name()] = grad[var.name()] + acc;
161 node_counter++;
162 }
163
164 } // namespace heyoka
165