1# -*- coding: utf-8 -*-
2
3import math
4
5import numpy as np
6
7from pyfr.integrators.std.base import BaseStdIntegrator
8from pyfr.mpiutil import get_comm_rank_root, get_mpi
9
10
class BaseStdController(BaseStdIntegrator):
    """Bookkeeping shared by the standard step-size controllers.

    Subclasses choose the step size and decide whether each attempted
    step is acceptable; this base class records the outcome, applies the
    optional solution filter and fires the completed-step handlers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # (dt, 'accept' | 'reject', err) records collected since the
        # last accepted step
        self.stepinfo = []

        # Apply the solution filter every this many accepted steps;
        # a value of zero disables filtering altogether
        self._fnsteps = self.cfg.getint('soln-filter', 'nsteps', '0')

        # A fresh (non-restarted) run fires the handlers once up front
        if not self.isrestart:
            self.completed_step_handlers(self)

    def _accept_step(self, dt, idxcurr, err=None):
        """Commit a successful step of size ``dt``.

        Advances ``tcurr``, bumps the accepted-step counters, makes
        ``idxcurr`` the current solution index, optionally filters it and
        then notifies the completed-step handlers (which may request an
        abort).  ``err`` is recorded in ``stepinfo`` alongside ``dt``.
        """
        self.tcurr += dt
        self.nacptsteps += 1
        self.nacptchain += 1
        self.stepinfo.append((dt, 'accept', err))

        self._idxcurr = idxcurr

        # Periodic solution filtering, keyed on the accepted-step count
        every = self._fnsteps
        if every and not self.nacptsteps % every:
            self.system.filt(idxcurr)

        # The solution has moved on: invalidate the cached solution and
        # the cached solution gradients
        self._curr_soln = self._curr_grad_soln = None

        # Let the handlers run, then stop if any plugin asked us to
        self.completed_step_handlers(self)
        self._check_abort()

        # Begin collecting records for the next step
        self.stepinfo = []

    def _reject_step(self, dt, idxold, err=None):
        """Discard a failed step of size ``dt``.

        Raises:
            RuntimeError: if ``dt`` is already at or below ``dtmin``, as
                the step size cannot be reduced any further.
        """
        if self.dtmin >= dt:
            raise RuntimeError('Minimum sized time step rejected')

        # A rejection breaks the run of consecutive accepted steps
        self.nacptchain = 0
        self.nrjctsteps += 1
        self.stepinfo.append((dt, 'reject', err))

        # Point the current-solution index back at ``idxold``
        self._idxcurr = idxold
61
62
class StdNoneController(BaseStdController):
    """Non-adaptive controller: every step taken is accepted."""

    controller_name = 'none'

    @property
    def controller_needs_errest(self):
        # Steps are never rejected, so no error estimate is required
        return False

    def advance_to(self, t):
        """Integrate forwards until the current time reaches ``t``.

        Steps use ``self._dt``, except that a step is shortened so as not
        to pass ``t``; no step is ever smaller than ``self.dtmin``.

        Raises:
            ValueError: if ``t`` lies before the current time.
        """
        if t < self.tcurr:
            raise ValueError('Advance time is in the past')

        while self.tcurr < t:
            # Clip the nominal step to what remains, floored at dtmin
            remaining = t - self.tcurr
            dt = max(min(remaining, self._dt), self.dtmin)

            # No error control here: take the step and keep it
            new_idx = self.step(self.tcurr, dt)
            self._accept_step(dt, new_idx)
83
84
class StdPIController(BaseStdController):
    """Adaptive step-size controller using a PI (proportional-integral)
    error feedback rule.

    Every step is followed by an error estimate.  Steps whose normalised
    error is below one are accepted and all others are rejected.  In both
    cases the next step size is rescaled by a factor computed from the
    current error and the error of the most recently accepted step.
    """

    controller_name = 'pi'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        sect = 'solver-time-integrator'

        # Maximum time step
        self.dtmax = self.cfg.getfloat(sect, 'dt-max', 1e2)

        # Error tolerances
        self._atol = self.cfg.getfloat(sect, 'atol')
        self._rtol = self.cfg.getfloat(sect, 'rtol')

        # Error norm
        self._norm = self.cfg.get(sect, 'errest-norm', 'l2')
        if self._norm not in {'l2', 'uniform'}:
            raise ValueError('Invalid error norm')

        # PI control values
        self._alpha = self.cfg.getfloat(sect, 'pi-alpha', 0.7)
        self._beta = self.cfg.getfloat(sect, 'pi-beta', 0.4)

        # Error of the most recently accepted step (1.0 before the first)
        self._errprev = 1.0

        # Step size adjustment factors
        self._saffac = self.cfg.getfloat(sect, 'safety-fact', 0.8)
        self._maxfac = self.cfg.getfloat(sect, 'max-fact', 2.5)
        self._minfac = self.cfg.getfloat(sect, 'min-fact', 0.3)

        if not self._minfac < 1 <= self._maxfac:
            raise ValueError('Invalid max-fact, min-fact')

    @property
    def controller_needs_errest(self):
        return True

    def _errest(self, rcurr, rprev, rerr):
        """Return the normalised error estimate for the last step.

        ``rcurr``, ``rprev`` and ``rerr`` are the indices returned by
        ``step``.  The reduction kernels yield local partial results that
        are summed (l2) or maxed (uniform) and then combined across MPI
        ranks.  A NaN estimate is reported as 100 so the step is rejected.
        """
        comm, _, _ = get_comm_rank_root()

        errest = self._get_reduction_kerns(rcurr, rprev, rerr, method='errest',
                                           norm=self._norm)

        # Obtain an estimate for the squared error
        self._queue.enqueue_and_run(errest, self._atol, self._rtol)

        # L2 norm
        if self._norm == 'l2':
            # Reduce locally (element types + field variables)
            err = np.array([sum(v for e in errest for v in e.retval)])

            # Reduce globally (MPI ranks)
            comm.Allreduce(get_mpi('in_place'), err, op=get_mpi('sum'))

            # Normalise; .item() is used because float() on an ndim-1
            # array is deprecated since NumPy 1.25
            err = math.sqrt(err.item() / self._gndofs)
        # L^∞ norm
        else:
            # Reduce locally (element types + field variables)
            err = np.array([max(v for e in errest for v in e.retval)])

            # Reduce globally (MPI ranks)
            comm.Allreduce(get_mpi('in_place'), err, op=get_mpi('max'))

            # Normalise
            err = math.sqrt(err.item())

        return err if not math.isnan(err) else 100

    def advance_to(self, t):
        """Integrate forwards until the current time reaches ``t``.

        The step size adapts after every attempt; rejected steps are
        retried with the reduced size.

        Raises:
            ValueError: if ``t`` lies before the current time.
            RuntimeError: (from ``_reject_step``) if a step of the minimum
                permitted size is rejected.
        """
        if t < self.tcurr:
            raise ValueError('Advance time is in the past')

        # Constants
        maxf = self._maxfac
        minf = self._minfac
        saff = self._saffac
        sord = self.stepper_order

        # PI exponents, scaled by the order of the stepper
        expa = self._alpha / sord
        expb = self._beta / sord

        while self.tcurr < t:
            # Decide on the time step
            dt = max(min(t - self.tcurr, self._dt, self.dtmax), self.dtmin)

            # Take the step
            idxcurr, idxprev, idxerr = self.step(self.tcurr, dt)

            # Estimate the error
            err = self._errest(idxcurr, idxprev, idxerr)

            # Determine time step adjustment factor.  Python raises
            # ZeroDivisionError for 0.0 to a negative power (e.g. an
            # exactly-zero error estimate); the limit is +inf, which the
            # clamp below maps onto the maximum growth factor.
            try:
                fac = err**-expa * self._errprev**expb
            except ZeroDivisionError:
                fac = math.inf
            fac = min(maxf, max(minf, saff*fac))

            # Compute the size of the next step
            self._dt = fac*dt

            # Decide if to accept or reject the step
            if err < 1.0:
                self._errprev = err
                self._accept_step(dt, idxcurr, err=err)
            else:
                self._reject_step(dt, idxprev, err=err)
192