# -*- coding: utf-8 -*-

import numpy as np

from pyfr.mpiutil import get_comm_rank_root, get_mpi
from pyfr.plugins.base import BasePlugin, init_csv


def _closest_upts_bf(etypes, eupts, pts):
    """Brute-force search for the solution point closest to each point.

    :param etypes: element type names, parallel to *eupts*.
    :param eupts: for each element type, an array of solution point
        locations whose first two axes index (solution point, element)
        and whose last axis holds the physical coordinates.
    :param pts: iterable of points to locate.
    :yields: for each point in *pts*, a tuple
        ``(dist, ploc, etype, (uidx, eidx))`` for the closest solution
        point over all element types; on a distance tie the earliest
        element type in *etypes* wins.
    """
    for p in pts:
        # Compute the distances between each point and p
        dists = [np.linalg.norm(e - p, axis=2) for e in eupts]

        # Get the index of the closest point to p for each element type
        amins = [np.unravel_index(np.argmin(d), d.shape) for d in dists]

        # Dereference to get the actual distances and locations
        dmins = [d[a] for d, a in zip(dists, amins)]
        plocs = [e[a] for e, a in zip(eupts, amins)]

        # Find the minimum across all element types.  Compare on the
        # distance alone: with a bare tuple comparison a tie would fall
        # through to the numpy location arrays and raise a ValueError
        yield min(zip(dmins, plocs, etypes, amins), key=lambda c: c[0])


def _closest_upts_kd(etypes, eupts, pts):
    """KD-tree search for the solution point closest to each point.

    Same inputs and yielded tuples as :func:`_closest_upts_bf`.  Raises
    ImportError (when first advanced, before anything is yielded) if
    ``scipy.spatial`` cannot be imported.
    """
    from scipy.spatial import cKDTree

    # Flatten the physical location arrays
    feupts = [e.reshape(-1, e.shape[-1]) for e in eupts]

    # For each element type construct a KD-tree of the upt locations
    trees = [cKDTree(f) for f in feupts]

    for p in pts:
        # Query the distance/index of the closest upt to p
        dmins, amins = zip(*[t.query(p) for t in trees])

        # Unravel the flat indices back to (upt, element) pairs
        amins = [np.unravel_index(i, e.shape[:2])
                 for i, e in zip(amins, eupts)]

        # Dereference to obtain the precise locations
        plocs = [e[a] for e, a in zip(eupts, amins)]

        # Reduce across element types, comparing on distance only (see
        # _closest_upts_bf for why a bare tuple comparison is unsafe)
        yield min(zip(dmins, plocs, etypes, amins), key=lambda c: c[0])


def _closest_upts(etypes, eupts, pts):
    """Yield the closest solution point to each point in *pts*.

    Prefers the KD-tree search and falls back to brute force when scipy
    is unavailable.  The ImportError is raised on the KD generator's
    first step, before any point is yielded, so the fallback never
    duplicates output.
    """
    try:
        # Attempt to use a KD-tree based approach
        yield from _closest_upts_kd(etypes, eupts, pts)
    except ImportError:
        # Otherwise fall back to brute force
        yield from _closest_upts_bf(etypes, eupts, pts)


class SamplerPlugin(BasePlugin):
    """Periodically sample the solution at user-specified points.

    Each point is mapped once, at construction, to the nearest solution
    point over all ranks (no interpolation is performed).  Every
    ``nsteps`` accepted steps the owning ranks extract those values and
    the root rank appends one CSV row per point.
    """

    name = 'sampler'
    systems = ['*']
    formulations = ['dual', 'std']

    def __init__(self, intg, cfgsect, suffix):
        # NOTE: contains collective MPI calls (allreduce/bcast), so every
        # rank must construct this plugin
        super().__init__(intg, cfgsect, suffix)

        # Underlying elements class
        self.elementscls = intg.system.elementscls

        # Output frequency
        self.nsteps = self.cfg.getint(cfgsect, 'nsteps')

        # List of points to be sampled and format; any format other than
        # 'primitive' outputs the conservative variables
        self.pts = self.cfg.getliteral(cfgsect, 'samp-pts')
        self.fmt = self.cfg.get(cfgsect, 'format', 'primitive')

        # MPI info
        comm, rank, root = get_comm_rank_root()

        # MPI rank responsible for each point and rank-indexed info
        self._ptsrank = ptsrank = []
        self._ptsinfo = ptsinfo = [[] for _ in range(comm.size)]

        # Physical location of the solution points, reordered so the
        # coordinate axis is last as the closest-point search expects
        plocs = [p.swapaxes(1, 2) for p in intg.system.ele_ploc_upts]

        # Locate the closest solution points in our partition
        closest = _closest_upts(intg.system.ele_types, plocs, self.pts)

        # Process these points
        for cp in closest:
            # Reduce over the distance to find the owning rank
            _, mrank = comm.allreduce((cp[0], rank), op=get_mpi('minloc'))

            # Store the rank responsible along with its
            # (ploc, etype, (uidx, eidx)) info, broadcast from that rank
            ptsrank.append(mrank)
            ptsinfo[mrank].append(
                comm.bcast(cp[1:] if rank == mrank else None, root=mrank)
            )

        # If we're the root rank then open the output file
        if rank == root:
            self.outf = init_csv(self.cfg, cfgsect, self._header)

    @property
    def _header(self):
        """Comma-separated CSV header row for the output file."""
        # NOTE(review): self.ndims is not set here; presumably provided
        # by BasePlugin -- confirm
        colnames = ['t'] + ['x', 'y', 'z'][:self.ndims]
        colnames += ['prank', 'etype', 'uidx', 'eidx']

        if self.fmt == 'primitive':
            colnames += self.elementscls.privarmap[self.ndims]
        else:
            colnames += self.elementscls.convarmap[self.ndims]

        return ','.join(colnames)

    def _process_samples(self, samps):
        """Return *samps* as a list of per-point rows, in output format.

        When the format is 'primitive' the samples are converted with
        ``elementscls.con_to_pri``; an empty sample set is passed
        through unchanged.
        """
        samps = np.array(samps)

        # If necessary then convert to primitive form
        if self.fmt == 'primitive' and samps.size:
            samps = self.elementscls.con_to_pri(samps.T, self.cfg)
            samps = np.array(samps).T

        return samps.tolist()

    def __call__(self, intg):
        """Sample the solution and, on the root rank, write CSV rows."""
        # Return if no output is due
        if intg.nacptsteps % self.nsteps:
            return

        # MPI info
        comm, rank, root = get_comm_rank_root()

        # Solution matrices indexed by element type
        solns = dict(zip(intg.system.ele_types, intg.soln))

        # Points we're responsible for sampling; index with the same rank
        # value that was used to build _ptsinfo in __init__
        ourpts = self._ptsinfo[rank]

        # Sample the solution matrices at these points
        samples = [solns[et][ui, :, ei] for _, et, (ui, ei) in ourpts]
        samples = self._process_samples(samples)

        # Gather to the root rank to give a list of points per rank
        samples = comm.gather(samples, root=root)

        # If we're the root rank then output
        if rank == root:
            # Collate; each rank's samples are in the same order as its
            # _ptsinfo entries, so pair them up lazily
            iters = [zip(pi, sp) for pi, sp in zip(self._ptsinfo, samples)]

            # Walk the points in their original configuration order
            for mrank in self._ptsrank:
                # Unpack
                (ploc, etype, idx), samp = next(iters[mrank])

                # Determine the physical mesh rank
                prank = intg.rallocs.mprankmap[mrank]

                # Write the output row
                print(intg.tcurr, *ploc, prank, etype, *idx, *samp,
                      sep=',', file=self.outf)

            # Flush to disk
            self.outf.flush()