1# -*- coding: utf-8 -*-
2
3import numpy as np
4
5from pyfr.mpiutil import get_comm_rank_root, get_mpi
6from pyfr.plugins.base import BasePlugin, init_csv
7
8
def _closest_upts_bf(etypes, eupts, pts):
    """Brute-force search for the solution point nearest to each sample point.

    :param etypes: Element type names, parallel to *eupts*.
    :param eupts: One 3D array of solution point locations per element type.
        The last axis holds the coordinates and the search runs over the
        first two axes (presumably ``(nupts, neles, ndims)`` — confirm
        against the caller).
    :param pts: Iterable of sample point coordinates.
    :returns: Generator yielding, for each point in *pts*, a tuple
        ``(distance, location, etype, (i, j))`` describing the closest
        solution point over all element types; ``(i, j)`` indexes the first
        two axes of the matching *eupts* array.  When several element types
        are equally close, the first one in *etypes* order is reported.
    """
    for p in pts:
        # Compute the distances between each point and p
        dists = [np.linalg.norm(e - p, axis=2) for e in eupts]

        # Get the index of the closest point to p for each element type
        amins = [np.unravel_index(np.argmin(d), d.shape) for d in dists]

        # Dereference to get the actual distances and locations
        dmins = [d[a] for d, a in zip(dists, amins)]
        plocs = [e[a] for e, a in zip(eupts, amins)]

        # Find the minimum across all element types.  Compare on distance
        # only: on a tie, comparing the whole tuples would fall through to
        # the location arrays and raise "truth value ... is ambiguous".
        yield min(zip(dmins, plocs, etypes, amins), key=lambda c: c[0])
23
24
def _closest_upts_kd(etypes, eupts, pts):
    """KD-tree search for the solution point nearest to each sample point.

    Requires SciPy; because this is a generator, the ImportError for a
    missing SciPy surfaces on the first iteration, not at call time.

    :param etypes: Element type names, parallel to *eupts*.
    :param eupts: One 3D array of solution point locations per element type.
        The last axis holds the coordinates and the search runs over the
        first two axes (presumably ``(nupts, neles, ndims)`` — confirm
        against the caller).
    :param pts: Iterable of sample point coordinates.
    :returns: Generator yielding, for each point in *pts*, a tuple
        ``(distance, location, etype, (i, j))`` describing the closest
        solution point over all element types; ``(i, j)`` indexes the first
        two axes of the matching *eupts* array.  When several element types
        are equally close, the first one in *etypes* order is reported.
    """
    from scipy.spatial import cKDTree

    # Flatten the physical location arrays
    feupts = [e.reshape(-1, e.shape[-1]) for e in eupts]

    # For each element type construct a KD-tree of the upt locations
    trees = [cKDTree(f) for f in feupts]

    for p in pts:
        # Query the distance/index of the closest upt to p
        dmins, amins = zip(*[t.query(p) for t in trees])

        # Unravel the flat indices back onto the first two array axes
        amins = [np.unravel_index(i, e.shape[:2])
                 for i, e in zip(amins, eupts)]

        # Dereference to obtain the precise locations
        plocs = [e[a] for e, a in zip(eupts, amins)]

        # Reduce across element types on distance only; comparing whole
        # tuples would reach the location arrays on a tie and raise
        yield min(zip(dmins, plocs, etypes, amins), key=lambda c: c[0])
47
48
def _closest_upts(etypes, eupts, pts):
    """Yield the closest solution point to each of *pts*.

    Uses the KD-tree search when SciPy's ``cKDTree`` is importable and falls
    back to the brute-force search otherwise.  Both yield tuples of
    ``(distance, location, etype, (i, j))``, one per point in *pts*.
    """
    # Decide on the implementation up front.  Keeping the try narrow means
    # an ImportError raised part-way through iteration propagates instead
    # of silently restarting the search and yielding duplicate results.
    try:
        from scipy.spatial import cKDTree  # noqa: F401
    except ImportError:
        search = _closest_upts_bf
    else:
        search = _closest_upts_kd

    yield from search(etypes, eupts, pts)
56
57
class SamplerPlugin(BasePlugin):
    """Periodically sample the solution at user-specified points.

    At construction every sample point is assigned to the MPI rank that
    holds the nearest solution point.  Every ``nsteps`` accepted steps the
    owning ranks read the solution at those points, and the root rank
    writes one CSV row per point.
    """

    name = 'sampler'
    systems = ['*']
    formulations = ['dual', 'std']

    def __init__(self, intg, cfgsect, suffix):
        super().__init__(intg, cfgsect, suffix)

        # Elements class: supplies variable name maps and con_to_pri
        self.elementscls = intg.system.elementscls

        # Sampling interval, the points themselves, and the variable set
        self.nsteps = self.cfg.getint(cfgsect, 'nsteps')
        self.pts = self.cfg.getliteral(cfgsect, 'samp-pts')
        self.fmt = self.cfg.get(cfgsect, 'format', 'primitive')

        comm, rank, root = get_comm_rank_root()

        # _ptsrank[k] is the rank sampling point k; _ptsinfo[r] lists the
        # (location, etype, (uidx, eidx)) records of the points rank r owns
        self._ptsrank = []
        self._ptsinfo = [[] for _ in range(comm.size)]

        # Solution point locations with axes 1 and 2 swapped so the search
        # sees coordinates on the last axis (presumably
        # (nupts, neles, ndims) afterwards — confirm)
        locs = [p.swapaxes(1, 2) for p in intg.system.ele_ploc_upts]

        # Nearest solution point within this partition, one per sample point
        candidates = _closest_upts(intg.system.ele_types, locs, self.pts)

        for dist, ploc, etype, idx in candidates:
            # Collective: the rank with the smallest distance wins
            _, owner = comm.allreduce((dist, rank), op=get_mpi('minloc'))

            # Collective: the winner shares its record with all ranks
            record = (ploc, etype, idx) if rank == owner else None
            record = comm.bcast(record, root=owner)

            self._ptsrank.append(owner)
            self._ptsinfo[owner].append(record)

        # Only the root rank writes the output file
        if rank == root:
            self.outf = init_csv(self.cfg, cfgsect, self._header)

    @property
    def _header(self):
        # Variable names follow the selected output format
        if self.fmt == 'primitive':
            varmap = self.elementscls.privarmap
        else:
            varmap = self.elementscls.convarmap

        cols = ['t', *['x', 'y', 'z'][:self.ndims],
                'prank', 'etype', 'uidx', 'eidx', *varmap[self.ndims]]
        return ','.join(cols)

    def _process_samples(self, samps):
        arr = np.array(samps)

        # Conservative output, or nothing sampled: pass through unchanged
        if self.fmt != 'primitive' or not arr.size:
            return arr.tolist()

        # con_to_pri is fed one row per variable, hence the transposes
        pri = self.elementscls.con_to_pri(arr.T, self.cfg)
        return np.array(pri).T.tolist()

    def __call__(self, intg):
        # Only sample on every nsteps-th accepted step
        if intg.nacptsteps % self.nsteps != 0:
            return

        comm, rank, root = get_comm_rank_root()

        # Read the solution at the points owned by this rank
        soln_by_etype = dict(zip(intg.system.ele_types, intg.soln))
        local = self._process_samples([
            soln_by_etype[etype][uidx, :, eidx]
            for _, etype, (uidx, eidx) in self._ptsinfo[comm.rank]
        ])

        # Collective: gather each rank's sample list onto the root
        gathered = comm.gather(local, root=root)
        if rank != root:
            return

        # Emit points in their configured order, pulling each one's record
        # and sample from the stream of the rank that owns it
        streams = [zip(info, samps)
                   for info, samps in zip(self._ptsinfo, gathered)]
        for owner in self._ptsrank:
            (ploc, etype, idx), samp = next(streams[owner])

            # Physical mesh rank corresponding to the owning MPI rank
            prank = intg.rallocs.mprankmap[owner]

            print(intg.tcurr, *ploc, prank, etype, *idx, *samp,
                  sep=',', file=self.outf)

        self.outf.flush()
165