1 // Copyright (c) 2016-2019 Anyar, Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 
17 #include "ascent/direct/State.h"
18 
19 namespace asc
20 {
21    template <typename T>
22    struct is_pair : std::false_type { };
23 
24    template <typename T, typename U>
25    struct is_pair<std::pair<T, U>> : std::true_type { };
26 
27    template <typename T>
28    constexpr bool is_pair_v = is_pair<T>::value;
29 
30    template <class value_t>
31    struct Propagator
32    {
33       Propagator() = default;
34       Propagator(const Propagator&) = default;
35       Propagator(Propagator&&) = default;
36       Propagator& operator=(const Propagator&) = default;
37       Propagator& operator=(Propagator&&) = default;
38       virtual ~Propagator() {}
39 
40       virtual void operator()(State&, const double) = 0; // inputs: state, dt (time step)
41 
42       size_t pass{};
43    };
44 
45    enum struct Phase
46    {
47       Link,
48       Init,
49       Update,
50       Postprop,
51       Postcalc
52    };
53 
54    struct Module
55    {
56       Module() = default;
57       Module(const Module&) = default;
58       Module(Module&&) = default;
59       Module& operator=(const Module&) = default;
60       Module& operator=(Module&&) = default;
61       virtual ~Module() = default;
62 
63       std::vector<State> states;
64 
65       template <class x_t, class xd_t>
66       void make_state(x_t& x, xd_t& xd)
67       {
68          states.emplace_back(x, xd);
69       }
70 
71       template <class x_t, class xd_t>
72       void make_states(x_t& x, xd_t& xd)
73       {
74          const size_t n = x.size();
75          for (size_t i = 0; i < n; ++i)
76          {
77             states.emplace_back(x[i], xd[i]);
78          }
79       }
80 
81       template <class data_t>
82       void make_states(data_t* x, data_t* xd, const size_t n)
83       {
84          for (size_t i = 0; i < n; ++i) {
85             states.emplace_back(x[i], xd[i]);
86          }
87       }
88 
89       template <class states_t>
90       void add_states(states_t& ext_states)
91       {
92          for (auto& state : states) {
93             ext_states.emplace_back(state);
94          }
95       }
96 
97       virtual void link() {} // linking modules
98       virtual void init() {} // initialization
99       virtual void operator()() {} // derivative accumulation
100       virtual void apply() {} // apply accumulations
101       virtual void propagate(Propagator<double>& propagator, const double dt)
102       {
103          for (auto& state : states) {
104             propagator(state, dt);
105          }
106       }
107       virtual void postprop() {} // post propagation calculations (every substep)
108       virtual void postcalc() {} // post integration calculations (every full step)
109 
110       bool init_called = false;
111    };
112 
113    template <class modules_t>
114    inline void init(modules_t& blocks)
115    {
116       for (auto& block : blocks)
117       {
118          if (!block->init_called)
119          {
120             block->init();
121             block->init_called = true;
122          }
123       }
124    }
125 
126    template <void(Module::* func)(), class modules_t>
127    void call_loop(modules_t& blocks)
128    {
129       if constexpr (is_pair_v<typename std::iterator_traits<typename modules_t::iterator>::value_type>)
130       {
131          for (auto& block : blocks) {
132             (block.second->*func)();
133          }
134       }
135       else
136       {
137          for (auto& block : blocks) {
138             (block->*func)();
139          }
140       }
141    }
142 
143    template <class modules_t>
144    void update(modules_t& blocks)
145    {
146       call_loop<&Module::operator()>(blocks);
147    }
148 
149    template <class modules_t>
150    void update(modules_t& blocks, asc::Module* run_first)
151    {
152       if (run_first) {
153          (*run_first)();
154       }
155       update(blocks);
156    }
157 
158    template <class modules_t>
159    void apply(modules_t& blocks)
160    {
161       call_loop<&Module::apply>(blocks);
162    }
163 
164    template <class modules_t, class propagator_t, class value_t>
165    void propagate(modules_t& blocks, propagator_t& propagator, const value_t dt)
166    {
167       if constexpr (is_pair_v<typename std::iterator_traits<typename modules_t::iterator>::value_type>)
168       {
169          for (auto& block : blocks) {
170             for (auto &state : block.second->states) {
171                if (propagator.pass == 0 && state.hist_len > 0) {
172                   state.x0_hist.push_back(*state.x);
173                   if (state.x0_hist.size() > state.hist_len) state.x0_hist.pop_front();
174 
175                   state.xd0_hist.push_back(*state.xd);
176                   if (state.xd0_hist.size() > state.hist_len) state.x0_hist.pop_front();
177                }
178             }
179             block.second->propagate(propagator, dt);
180          }
181       }
182       else
183       {
184          for (auto& block : blocks) {
185             for (auto &state : block->states) {
186                if (propagator.pass == 0 && state.hist_len > 0) {
187                   state.x0_hist.push_back(*state.x);
188                   if (state.x0_hist.size() > state.hist_len) state.x0_hist.pop_front();
189 
190                   state.xd0_hist.push_back(*state.xd);
191                   if (state.xd0_hist.size() > state.hist_len) state.x0_hist.pop_front();
192                }
193             }
194             block->propagate(propagator, dt);
195          }
196       }
197    }
198 
199    template <class modules_t>
200    void postprop(modules_t& blocks)
201    {
202       call_loop<&Module::postprop>(blocks);
203    }
204 
205    template <class modules_t>
206    inline void postcalc(modules_t& blocks)
207    {
208       call_loop<&Module::postcalc>(blocks);
209    }
210 
211    template <class states_t, class ptr_t>
212    inline void add_states(states_t& states, std::vector<ptr_t>& blocks)
213    {
214       for (auto& block : blocks)
215       {
216          auto& m_states = block->states;
217          for (auto& state : m_states) {
218             states.emplace_back(state);
219          }
220       }
221    }
222 
223    template <class states_t, class ptr_t>
224    inline void add_states(states_t& states, ptr_t& block)
225    {
226       for (auto& state : block->states) {
227          states.emplace_back(state);
228       }
229    }
230 }