# -*- coding: utf-8 -*-

from collections import Counter, defaultdict, namedtuple
import itertools as it
import re
import uuid

import numpy as np

from pyfr.inifile import Inifile


# Dual graph in CSR form: vtab[v]:vtab[v + 1] indexes the neighbours of
# vertex v in etab; vwts/ewts hold the vertex and edge weights
Graph = namedtuple('Graph', ['vtab', 'etab', 'vwts', 'ewts'])


class BasePartitioner(object):
    """Common machinery for splitting a mesh (and its solutions) into parts.

    Subclasses must provide the class attributes ``dflt_opts``,
    ``int_opts`` and ``enum_opts`` (read by ``__init__``) and override
    :meth:`_partition_graph`.
    """

    def __init__(self, partwts, elewts, nsubeles=64, opts=None):
        """
        :param partwts: sequence of per-partition weights; its length is the
            number of partitions produced.
        :param elewts: mapping from element type name to vertex weight.
        :param nsubeles: target number of elements per sub-partition when
            renumbering the interior of each partition.
        :param opts: optional mapping of partitioner options which override
            ``self.dflt_opts``.
        :raises ValueError: if an option is neither an int nor enum option.
        """
        opts = {} if opts is None else opts

        self.partwts = partwts
        self.elewts = elewts
        self.nparts = len(partwts)
        self.nsubeles = nsubeles

        # Parse the options list
        self.opts = {}
        for k, v in dict(self.dflt_opts, **opts).items():
            if k in self.int_opts:
                self.opts[k] = int(v)
            elif k in self.enum_opts:
                self.opts[k] = self.enum_opts[k][v]
            else:
                raise ValueError('Invalid partitioner option')

    def _combine_mesh_parts(self, mesh):
        """Merge every partition of ``mesh`` into a single partition 0.

        :returns: ``(newmesh, rnum)`` where ``newmesh`` is a dict holding
            ``con_p0``, ``spt_{etype}_p0`` (plus its ``'linear'`` flags) and
            ``bcon_{name}_p0`` arrays, and ``rnum[etype]`` maps each
            ``(old_part, old_idx)`` to ``(0, combined_idx)``.
        :raises ValueError: if an inter-partition interface has no partner.
        """
        # Get the per-partition element counts
        pinf = mesh.partition_info('spt')

        # Shape points, linear flags, element offsets, and remapping table
        spts = defaultdict(list)
        linf = defaultdict(list)
        offs = defaultdict(dict)
        rnum = defaultdict(dict)

        for en, pn in pinf.items():
            for i, n in enumerate(pn):
                if n > 0:
                    # Elements of partition i start after those already
                    # gathered (element index is axis 1 of spt arrays)
                    offs[en][i] = off = sum(s.shape[1] for s in spts[en])
                    spts[en].append(mesh[f'spt_{en}_p{i}'])
                    linf[en].append(mesh[f'spt_{en}_p{i}', 'linear'])
                    rnum[en].update(((i, j), (0, off + j)) for j in range(n))

        def offset_con(con, pr):
            # Shift element indices from partition pr into the combined
            # numbering; the unicode dtype lets f0 compare against str names
            con = con.copy().astype('U4,i4,i1,i2')

            for en, pn in pinf.items():
                if pn[pr] > 0:
                    con['f1'][np.where(con['f0'] == en)] += offs[en][pr]

            return con

        # Connectivity
        intcon, mpicon, bccon = [], {}, defaultdict(list)

        for f in mesh:
            if (mi := re.match(r'con_p(\d+)$', f)):
                intcon.append(offset_con(mesh[f], int(mi.group(1))))
            elif (mm := re.match(r'con_p(\d+)p(\d+)$', f)):
                l, r = int(mm.group(1)), int(mm.group(2))
                lcon = offset_con(mesh[f], l)

                # Each interface is stored once per side; once both halves
                # have been seen they stack into an internal (2, n) array
                if (r, l) in mpicon:
                    rcon = mpicon.pop((r, l))
                    intcon.append(np.vstack([lcon, rcon]))
                else:
                    mpicon[l, r] = lcon
            elif (bc := re.match(r'bcon_(.+?)_p(\d+)$', f)):
                name, l = bc.group(1), int(bc.group(2))
                bccon[name].append(offset_con(mesh[f], l))

        # Any half-interface left over would silently lose its faces
        if mpicon:
            unpaired = ', '.join(f'con_p{l}p{r}' for l, r in mpicon)
            raise ValueError(f'Unpaired partition interfaces: {unpaired}')

        # Output data type
        dtype = 'U4,i4,i1,i2'

        # Concatenate these arrays to form the new mesh
        newmesh = {'con_p0': np.hstack(intcon).astype(dtype)}

        for en in spts:
            newmesh[f'spt_{en}_p0'] = np.hstack(spts[en])
            newmesh[f'spt_{en}_p0', 'linear'] = np.hstack(linf[en])

        for k, v in bccon.items():
            newmesh[f'bcon_{k}_p0'] = np.hstack(v).astype(dtype)

        return newmesh, rnum

    def _combine_soln_parts(self, soln, prefix):
        """Merge per-partition ``{prefix}_{etype}_p*`` arrays into ``_p0``.

        Arrays are stacked along their last (element) axis in the order
        reported by ``soln.array_info(prefix)``.
        """
        newsoln = defaultdict(list)

        for f, (en, _) in soln.array_info(prefix).items():
            newsoln[f'{prefix}_{en}_p0'].append(soln[f])

        return {k: np.dstack(v) for k, v in newsoln.items()}

    def _construct_graph(self, con):
        """Build the dual graph of a (2, nfaces) connectivity array.

        :returns: ``(graph, vetimap)`` where ``vetimap[v]`` is the
            ``(etype, eidx)`` tuple of graph vertex ``v``.
        """
        # Edges of the dual graph, in both directions
        con = con[['f0', 'f1']]
        con = np.hstack([con, con[::-1]])

        # Sort by the left hand side (index primary, type secondary) so all
        # edges of a given element are contiguous
        idx = np.lexsort([con['f0'][0], con['f1'][0]])
        con = con[:, idx]

        # Left and right hand side element types/indices
        lhs, rhs = con

        # Compute vertex offsets
        vtab = np.where(lhs[1:] != lhs[:-1])[0]
        vtab = np.concatenate(([0], vtab + 1, [len(lhs)]))

        # Compute the element type/index to vertex number map
        vetimap = lhs[vtab[:-1]].tolist()
        etivmap = {k: v for v, k in enumerate(vetimap)}

        # Prepare the list of edges for each vertex
        etab = np.array([etivmap[r] for r in rhs.tolist()])

        # Prepare the list of vertex and edge weights
        vwts = np.array([self.elewts[t] for t, _ in vetimap])
        ewts = np.ones_like(etab)

        return Graph(vtab, etab, vwts, ewts), vetimap

    def _partition_graph(self, graph, partwts):
        """Split ``graph`` into ``len(partwts)`` weighted parts.

        Subclasses must override this.  Callers expect one part number per
        graph vertex, in vertex order, in an object supporting ``tolist()``
        (presumably a NumPy array).
        """
        raise NotImplementedError(
            f'{type(self).__name__} does not implement _partition_graph'
        )

    def _renumber_verts(self, mesh, vetimap, vparts):
        """Reorder elements for locality within each partition.

        Boundary elements (those sharing a face with another partition) are
        numbered first.  Interior elements follow, grouped by sub-partition
        with curved elements ahead of linear ones.

        :returns: ``(nvetimap, nvparts)`` — the reordered element list and
            matching part numbers.
        """
        pscon = [[] for _ in range(self.nparts)]
        vpartmap, bndeti = dict(zip(vetimap, vparts)), set()

        # Construct per-partition connectivity arrays and tag elements
        # which are on partition boundaries
        for l, r in zip(*mesh['con_p0'][['f0', 'f1']].tolist()):
            if vpartmap[l] == vpartmap[r]:
                pscon[vpartmap[l]].append([l, r])
            else:
                pscon[vpartmap[l]].append([l, r])
                pscon[vpartmap[r]].append([l, r])
                bndeti |= {l, r}

        # Start by assigning the lowest numbers to these boundary elements;
        # sort them since set order of str-keyed tuples varies per process
        nvetimap = sorted(bndeti)
        nvparts = [vpartmap[eti] for eti in nvetimap]

        # Use sub-partitioning to assign interior element numbers
        for part, scon in enumerate(pscon):
            # Construct a graph for this partition
            scon = np.array(scon, dtype='U4,i4').T
            sgraph, svetimap = self._construct_graph(scon)

            # Determine the number of sub-partitions
            nsp = len(svetimap) // self.nsubeles + 1

            # Partition the graph
            if nsp == 1:
                svparts = [0]*len(svetimap)
            else:
                svparts = self._partition_graph(sgraph, [1]*nsp)

            # Group elements according to their type (linear vs curved)
            # and sub-partition number
            linsvetimap = [[] for _ in range(nsp)]
            cursvetimap = [[] for _ in range(nsp)]
            for (etype, eidx), spart in zip(svetimap, svparts):
                # Boundary elements (including ones owned by neighbouring
                # partitions) have already been numbered
                if (etype, eidx) in bndeti:
                    continue

                if mesh[f'spt_{etype}_p0', 'linear'][eidx]:
                    linsvetimap[spart].append((etype, eidx))
                else:
                    cursvetimap[spart].append((etype, eidx))

            # Append to the global list
            nvetimap.extend(it.chain(*cursvetimap, *linsvetimap))
            nvparts.extend([part]*sum(map(len, cursvetimap + linsvetimap)))

        return nvetimap, nvparts

    def _partition_spts(self, mesh, vetimap, vparts):
        """Split combined shape points and linear flags per part.

        Output keys are ``spt_{etype}_p{part}`` and the matching
        ``(key, 'linear')`` entries; elements keep the order of ``vetimap``.
        """
        spt_px = defaultdict(list)
        lin_px = defaultdict(list)

        for (etype, eidxg), part in zip(vetimap, vparts):
            f = f'spt_{etype}_p0'

            spt_px[etype, part].append(mesh[f][:, eidxg, :])
            lin_px[etype, part].append(mesh[f, 'linear'][eidxg])

        newmesh = {}
        for (etype, pn), pts in spt_px.items():
            f = f'spt_{etype}_p{pn}'

            # Stacking gives (neles, npts, ndims); restore element axis 1
            newmesh[f] = np.array(pts).swapaxes(0, 1)
            newmesh[f, 'linear'] = np.array(lin_px[etype, pn])

        return newmesh

    def _partition_soln(self, soln, prefix, vetimap, vparts):
        """Split combined ``{prefix}_{etype}_p0`` arrays per part.

        The element index is the last axis of each solution array.
        """
        soln_px = defaultdict(list)
        for (etype, eidxg), part in zip(vetimap, vparts):
            f = f'{prefix}_{etype}_p0'

            soln_px[etype, part].append(soln[f][..., eidxg])

        return {f'{prefix}_{etype}_p{pn}': np.dstack(v)
                for (etype, pn), v in soln_px.items()}

    def _partition_con(self, mesh, vetimap, vparts):
        """Split combined connectivity into per-part arrays.

        :returns: ``(con, eleglmap)`` where ``con`` holds ``con_p{x}``
            (internal, shape (2, n)), ``con_p{x}p{y}`` (interface) and
            ``bcon_{name}_p{x}`` arrays, and ``eleglmap`` maps each global
            ``(etype, eidx)`` to ``(part, local_idx)``.
        """
        con_px = defaultdict(list)
        con_pxpy = defaultdict(list)
        bcon_px = defaultdict(list)

        # Global-to-local element index map
        eleglmap = {}
        pcounter = Counter()

        for (etype, eidxg), part in zip(vetimap, vparts):
            eleglmap[etype, eidxg] = (part, pcounter[etype, part])
            pcounter[etype, part] += 1

        # Generate the face connectivity
        for l, r in zip(*mesh['con_p0'].tolist()):
            letype, leidxg, lfidx, lflags = l
            retype, reidxg, rfidx, rflags = r

            lpart, leidxl = eleglmap[letype, leidxg]
            rpart, reidxl = eleglmap[retype, reidxg]

            conl = (letype, leidxl, lfidx, lflags)
            conr = (retype, reidxl, rfidx, rflags)

            if lpart == rpart:
                con_px[lpart].append([conl, conr])
            else:
                # Both sides are appended in lockstep so the i-th entries
                # of con_pXpY and con_pYpX describe the same face
                con_pxpy[lpart, rpart].append(conl)
                con_pxpy[rpart, lpart].append(conr)

        # Generate boundary conditions (skip tuple keys such as linear flags)
        for f in filter(lambda f: isinstance(f, str), mesh):
            if (m := re.match(r'bcon_(.+?)_p0$', f)):
                for lpetype, leidxg, lfidx, lflags in mesh[f].tolist():
                    lpart, leidxl = eleglmap[lpetype, leidxg]
                    conl = (lpetype, leidxl, lfidx, lflags)

                    bcon_px[m.group(1), lpart].append(conl)

        # Output data type
        dtype = 'S4,i4,i1,i2'

        # Output
        con = {}

        for px, v in con_px.items():
            con[f'con_p{px}'] = np.array(v, dtype=dtype).T

        for (px, py), v in con_pxpy.items():
            con[f'con_p{px}p{py}'] = np.array(v, dtype=dtype)

        for (bcname, px), v in bcon_px.items():
            con[f'bcon_{bcname}_p{px}'] = np.array(v, dtype=dtype)

        return con, eleglmap

    def partition(self, mesh):
        """Repartition ``mesh`` into ``self.nparts`` parts.

        :returns: ``(newmesh, rnum, partition_soln)`` where ``rnum[etype]``
            maps each original ``(part, idx)`` to its new ``(part, idx)``
            and ``partition_soln`` converts a solution for ``mesh`` to the
            new layout.
        :raises RuntimeError: if the graph partitioner yields a different
            number of parts than requested.
        """
        # Extract the current UUID from the mesh
        curruuid = mesh['mesh_uuid']

        # Combine any pre-existing partitions
        mesh, rnum = self._combine_mesh_parts(mesh)

        # Obtain the dual graph for this mesh
        graph, vetimap = self._construct_graph(mesh['con_p0'])

        # Partition the graph
        if self.nparts > 1:
            vparts = self._partition_graph(graph, self.partwts).tolist()

            if (n := len(set(vparts))) != self.nparts:
                raise RuntimeError(f'Partitioner error: mesh has {n} parts '
                                   f'versus goal of {self.nparts}')
        else:
            vparts = [0]*len(vetimap)

        # Renumber vertices
        vetimap, vparts = self._renumber_verts(mesh, vetimap, vparts)

        # Partition the connectivity portion of the mesh
        newmesh, eleglmap = self._partition_con(mesh, vetimap, vparts)

        # Handle the shape points
        newmesh.update(self._partition_spts(mesh, vetimap, vparts))

        # Update the renumbering table (only values change, so updating
        # while iterating is safe)
        for etype, emap in rnum.items():
            for k, (_, eidx) in emap.items():
                emap[k] = eleglmap[etype, eidx]

        # Generate a new UUID for the mesh
        newmesh['mesh_uuid'] = newuuid = str(uuid.uuid4())

        # Build the solution converter
        def partition_soln(soln):
            # Check the UUID
            if curruuid != soln['mesh_uuid']:
                raise ValueError('Mismatched solution/mesh')

            # Obtain the prefix
            prefix = Inifile(soln['stats']).get('data', 'prefix')

            # Combine and repartition the solution
            newsoln = self._combine_soln_parts(soln, prefix)
            newsoln = self._partition_soln(newsoln, prefix, vetimap, vparts)

            # Copy over the metadata
            for f in soln:
                if re.match(r'stats|config|plugins', f):
                    newsoln[f] = soln[f]

            # Apply the new UUID
            newsoln['mesh_uuid'] = newuuid

            return newsoln

        return newmesh, rnum, partition_soln