1# -*- coding: utf-8 -*-
2
3import re
4
5import numpy as np
6
7from pyfr.inifile import Inifile
8from pyfr.mpiutil import get_comm_rank_root
9from pyfr.nputil import npeval
10from pyfr.plugins.base import BasePlugin, PostactionMixin, RegionMixin
11from pyfr.writers.native import NativeWriter
12
13
class TavgPlugin(PostactionMixin, RegionMixin, BasePlugin):
    """Time-average user-defined expressions and periodically write them out.

    Expressions are read from the plugin's config section:

    * ``avg-<name>`` keys give expressions that are integrated in time.
      They may reference the primitive variables and their gradients; the
      latter use names of the form ``grad_<var>_<x|y|z>``.
    * ``fun-avg-<name>`` keys give expressions of the ``<name>`` averages.
      They are evaluated only when output is written.

    In ``windowed`` mode each output holds the average over the interval
    since the previous output. In ``continuous`` mode each output holds the
    average over the whole interval since averaging started.

    NOTE(review): ``cfg``, ``cfgsect``, ``ndims``, ``tol``, ``_ele_regions``,
    ``_ele_region_data`` and ``_invoke_postaction`` are not defined in this
    class. They are presumably supplied by ``BasePlugin`` and the mixins --
    confirm there.
    """

    name = 'tavg'
    systems = ['*']
    formulations = ['dual', 'std']

    def __init__(self, intg, cfgsect, suffix=None):
        super().__init__(intg, cfgsect, suffix)

        # Underlying elements class
        self.elementscls = intg.system.elementscls

        # Averaging mode
        self.mode = self.cfg.get(cfgsect, 'mode', 'windowed')
        if self.mode not in {'continuous', 'windowed'}:
            raise ValueError('Invalid averaging mode')

        # Expressions pre-processing
        self._prepare_exprs()

        # Output data type; accumulation itself is always done in float64
        fpdtype = self.cfg.get(cfgsect, 'precision', 'single')
        if fpdtype == 'single':
            self.fpdtype = np.float32
        elif fpdtype == 'double':
            self.fpdtype = np.float64
        else:
            raise ValueError('Invalid floating point data type')

        # Base output directory and file name
        basedir = self.cfg.getpath(self.cfgsect, 'basedir', '.', abs=True)
        basename = self.cfg.get(self.cfgsect, 'basename')

        # Construct the file writer
        self._writer = NativeWriter(intg, basedir, basename, 'tavg')

        # Gradient pre-processing
        self._init_gradients(intg)

        # Time averaging parameters
        self.tstart = self.cfg.getfloat(cfgsect, 'tstart', 0.0)
        self.dtout = self.cfg.getfloat(cfgsect, 'dt-out')
        self.nsteps = self.cfg.getint(cfgsect, 'nsteps')

        # Register our output times with the integrator
        intg.call_plugin_dt(self.dtout)

        # Mark ourselves as not currently averaging
        self._started = False

    def _prepare_exprs(self):
        """Parse the averaging and functional expressions from the config.

        Populates ``anames`` and ``aexprs`` from the ``avg-*`` keys, and
        ``fexprs`` from the ``fun-avg-*`` keys. Values from the
        ``constants`` section are substituted into both. ``outfields``
        lists the keys of both kinds, in the order their values appear in
        the output: accumulated fields first, then functional ones.

        Raises:
            ValueError: if no ``avg-*`` expressions are configured.
        """
        cfg, cfgsect = self.cfg, self.cfgsect
        c = cfg.items_as('constants', float)
        self.anames, self.aexprs = [], []
        self.outfields, self.fexprs = [], []

        # Iterate over accumulation expressions first
        for k in cfg.items(cfgsect):
            if k.startswith('avg-'):
                self.anames.append(k[4:])
                self.aexprs.append(cfg.getexpr(cfgsect, k, subs=c))
                self.outfields.append(k)

        # Followed by any functional expressions
        for k in cfg.items(cfgsect):
            if k.startswith('fun-avg-'):
                self.fexprs.append(cfg.getexpr(cfgsect, k, subs=c))
                self.outfields.append(k)

        # Without any accumulation expressions the first evaluation would
        # fail deep inside np.dstack (and only on ranks that own elements in
        # the region), so fail early and uniformly here instead
        if not self.aexprs:
            raise ValueError('No averaging expressions specified')

    def _init_gradients(self, intg):
        """Work out which primitive-variable gradients the expressions use.

        Sets ``_gradpinfo`` to a list of ``(pname, index)`` pairs. Here
        ``index`` is the position of ``pname`` in the element class's
        ``privarmap`` for this dimensionality. An unknown variable name
        therefore raises ``ValueError`` from ``list.index``.
        """
        # Determine what gradients, if any, are required
        gradpnames = {pname for ex in self.aexprs
                      for pname in re.findall(r'\bgrad_(.+?)_[xyz]\b', ex)}

        privarmap = self.elementscls.privarmap[self.ndims]
        self._gradpinfo = [(pname, privarmap.index(pname))
                           for pname in gradpnames]

    def _init_accumex(self, intg):
        """Start averaging at the current time.

        Records the current expression values as the left end-point of the
        first integration interval and zeroes the float64 accumulators.
        """
        self.prevt = self.tout_last = intg.tcurr
        self.prevex = self._eval_acc_exprs(intg)
        self.accex = [np.zeros_like(p, dtype=np.float64) for p in self.prevex]

        # Extra state for continuous accumulation: a running total over all
        # windows, plus the time at which averaging actually began
        if self.mode == 'continuous':
            self.caccex = [np.zeros_like(a) for a in self.accex]
            self.tstart_actual = intg.tcurr

    def _eval_acc_exprs(self, intg):
        """Evaluate the ``avg-*`` expressions on the current solution.

        Returns one array per entry of ``self._ele_regions``. Each array is
        built by ``np.dstack`` of the per-expression results followed by
        ``swapaxes(1, 2)``, so axis 1 indexes the expressions (assuming each
        expression evaluates to a 2-D array -- confirm).
        """
        exprs = []

        # Get the primitive variable names
        pnames = self.elementscls.privarmap[self.ndims]

        # Iterate over each element type in the simulation
        for idx, etype, rgn in self._ele_regions:
            # NOTE(review): the swap presumably moves the variable axis to
            # the front for con_to_pri -- confirm the soln array layout
            soln = intg.soln[idx][..., rgn].swapaxes(0, 1)

            # Convert from conservative to primitive variables
            psolns = self.elementscls.con_to_pri(soln, self.cfg)

            # Prepare the substitutions dictionary
            subs = dict(zip(pnames, psolns))

            # Prepare any required gradients
            if self._gradpinfo:
                # Compute the gradients
                grad_soln = np.rollaxis(intg.grad_soln[idx], 2)[..., rgn]

                # Transform from conservative to primitive gradients
                pgrads = self.elementscls.grad_con_to_pri(soln, grad_soln,
                                                          self.cfg)

                # Add them to the substitutions dictionary; pidx is the
                # variable index, distinct from the element-type idx above
                for pname, pidx in self._gradpinfo:
                    for dim, grad in zip('xyz', pgrads[pidx]):
                        subs[f'grad_{pname}_{dim}'] = grad

            # Evaluate the expressions
            exprs.append([npeval(v, subs) for v in self.aexprs])

        # Stack up the expressions for each element type and return
        return [np.dstack(exs).swapaxes(1, 2) for exs in exprs]

    def _eval_fun_exprs(self, intg, accex):
        """Evaluate the ``fun-avg-*`` expressions of the averages ``accex``.

        ``accex`` holds one array per element type, with the averaged
        expressions along axis 1. The return value has the same layout,
        with one entry per functional expression.
        """
        exprs = []

        # Iterate over each element type in our averaging region
        for avals in accex:
            # Prepare the substitution dictionary
            subs = dict(zip(self.anames, avals.swapaxes(0, 1)))

            exprs.append([npeval(v, subs) for v in self.fexprs])

        # Stack up the expressions for each element type and return
        return [np.dstack(exs).swapaxes(1, 2) for exs in exprs]

    def __call__(self, intg):
        """Advance the time average and, when due, write it out.

        Does nothing before ``tstart``. Expressions are accumulated every
        ``nsteps`` accepted steps, and also whenever ``dt-out`` has elapsed
        (to within ``self.tol``) since the last output. In the latter case
        the averages are then written.
        """
        # If we are not supposed to be averaging yet then return
        if intg.tcurr < self.tstart:
            return

        # If necessary, run the start-up routines
        if not self._started:
            self._init_accumex(intg)
            self._started = True

        # See if we are due to write and/or accumulate this step
        dowrite = intg.tcurr - self.tout_last >= self.dtout - self.tol
        doaccum = intg.nacptsteps % self.nsteps == 0

        if dowrite or doaccum:
            # Accumulate; always do this even when just writing so that the
            # integral extends right up to the output time
            self._accumulate_tavg(intg)

            if dowrite:
                self._write_tavg(intg)

    def _accumulate_tavg(self, intg):
        """Integrate the expressions from ``prevt`` to ``intg.tcurr``.

        Uses the trapezium rule on the values at the two end-points, then
        makes the current time and values the new left end-point.
        """
        # Evaluate the time averaging expressions
        currex = self._eval_acc_exprs(intg)

        # Accumulate them
        for a, p, c in zip(self.accex, self.prevex, currex):
            a += 0.5*(intg.tcurr - self.prevt)*(p + c)

        # Save the time and solution
        self.prevt = intg.tcurr
        self.prevex = currex

    def _write_tavg(self, intg):
        """Normalise the accumulated expressions and write them to disk.

        Also invokes any registered post-action. It then resets the
        per-window accumulators so the next window starts at ``intg.tcurr``.
        """
        _, rank, root = get_comm_rank_root()

        if self.mode == 'windowed':
            # Average over the window since the previous output
            accex = self.accex
            tstart = self.tout_last
        else:
            # Fold this window into the running total, then average over
            # everything since averaging began
            for a, c in zip(self.accex, self.caccex):
                c += a

            accex = self.caccex
            tstart = self.tstart_actual

        # Normalise the accumulated expressions
        tavg = [a / (intg.tcurr - tstart) for a in accex]

        # Evaluate any functional expressions and append them after the
        # averaged ones along the expression axis
        if self.fexprs:
            funex = self._eval_fun_exprs(intg, tavg)
            tavg = [np.hstack([a, f]) for a, f in zip(tavg, funex)]

        # Form the output records to be written to disk
        data = dict(self._ele_region_data)
        for (idx, etype, rgn), d in zip(self._ele_regions, tavg):
            data[etype] = d.astype(self.fpdtype)

        # If we are the root rank then prepare the metadata
        if rank == root:
            stats = Inifile()
            stats.set('data', 'prefix', 'tavg')
            stats.set('data', 'fields', ','.join(self.outfields))
            stats.set('tavg', 'tstart', tstart)
            stats.set('tavg', 'tend', intg.tcurr)
            intg.collect_stats(stats)

            metadata = dict(intg.cfgmeta,
                            stats=stats.tostr(),
                            mesh_uuid=intg.mesh_uuid)
        else:
            metadata = None

        # Write to disk
        solnfname = self._writer.write(data, intg.tcurr, metadata)

        # If a post-action has been registered then invoke it
        self._invoke_postaction(intg=intg, mesh=intg.system.mesh.fname,
                                soln=solnfname, t=intg.tcurr)

        # Reset the accumulators
        for a in self.accex:
            a.fill(0)

        self.tout_last = intg.tcurr