1# -*- coding: utf-8 -*-
2
3from collections import Counter, defaultdict, namedtuple
4import itertools as it
5import re
6import uuid
7
8import numpy as np
9
10from pyfr.inifile import Inifile
11
12
# Dual graph in compressed adjacency form: vtab holds the per-vertex
# offsets into etab, etab the neighbouring vertex numbers, and vwts/ewts
# the vertex and edge weights
Graph = namedtuple('Graph', ('vtab', 'etab', 'vwts', 'ewts'))
14
15
16class BasePartitioner(object):
17    def __init__(self, partwts, elewts, nsubeles=64, opts={}):
18        self.partwts = partwts
19        self.elewts = elewts
20        self.nparts = len(partwts)
21        self.nsubeles = nsubeles
22
23        # Parse the options list
24        self.opts = {}
25        for k, v in dict(self.dflt_opts, **opts).items():
26            if k in self.int_opts:
27                self.opts[k] = int(v)
28            elif k in self.enum_opts:
29                self.opts[k] = self.enum_opts[k][v]
30            else:
31                raise ValueError('Invalid partitioner option')
32
33    def _combine_mesh_parts(self, mesh):
34        # Get the per-partition element counts
35        pinf = mesh.partition_info('spt')
36
37        # Shape points, linear flags, element offsets, and remapping table
38        spts = defaultdict(list)
39        linf = defaultdict(list)
40        offs = defaultdict(dict)
41        rnum = defaultdict(dict)
42
43        for en, pn in pinf.items():
44            for i, n in enumerate(pn):
45                if n > 0:
46                    offs[en][i] = off = sum(s.shape[1] for s in spts[en])
47                    spts[en].append(mesh[f'spt_{en}_p{i}'])
48                    linf[en].append(mesh[f'spt_{en}_p{i}', 'linear'])
49                    rnum[en].update(((i, j), (0, off + j)) for j in range(n))
50
51        def offset_con(con, pr):
52            con = con.copy().astype('U4,i4,i1,i2')
53
54            for en, pn in pinf.items():
55                if pn[pr] > 0:
56                    con['f1'][np.where(con['f0'] == en)] += offs[en][pr]
57
58            return con
59
60        # Connectivity
61        intcon, mpicon, bccon = [], {}, defaultdict(list)
62
63        for f in mesh:
64            if (mi := re.match(r'con_p(\d+)$', f)):
65                intcon.append(offset_con(mesh[f], int(mi.group(1))))
66            elif (mm := re.match(r'con_p(\d+)p(\d+)$', f)):
67                l, r = int(mm.group(1)), int(mm.group(2))
68                lcon = offset_con(mesh[f], l)
69
70                if (r, l) in mpicon:
71                    rcon = mpicon.pop((r, l))
72                    intcon.append(np.vstack([lcon, rcon]))
73                else:
74                    mpicon[l, r] = lcon
75            elif (bc := re.match(r'bcon_(.+?)_p(\d+)$', f)):
76                name, l = bc.group(1), int(bc.group(2))
77                bccon[name].append(offset_con(mesh[f], l))
78
79        # Output data type
80        dtype = 'U4,i4,i1,i2'
81
82        # Concatenate these arrays to from the new mesh
83        newmesh = {'con_p0': np.hstack(intcon).astype(dtype)}
84
85        for en in spts:
86            newmesh[f'spt_{en}_p0'] = np.hstack(spts[en])
87            newmesh[f'spt_{en}_p0', 'linear'] = np.hstack(linf[en])
88
89        for k, v in bccon.items():
90            newmesh[f'bcon_{k}_p0'] = np.hstack(v).astype(dtype)
91
92        return newmesh, rnum
93
94    def _combine_soln_parts(self, soln, prefix):
95        newsoln = defaultdict(list)
96
97        for f, (en, shape) in soln.array_info(prefix).items():
98            newsoln[f'{prefix}_{en}_p0'].append(soln[f])
99
100        return {k: np.dstack(v) for k, v in newsoln.items()}
101
102    def _construct_graph(self, con):
103        # Edges of the dual graph
104        con = con[['f0', 'f1']]
105        con = np.hstack([con, con[::-1]])
106
107        # Sort by the left hand side
108        idx = np.lexsort([con['f0'][0], con['f1'][0]])
109        con = con[:, idx]
110
111        # Left and right hand side element types/indicies
112        lhs, rhs = con
113
114        # Compute vertex offsets
115        vtab = np.where(lhs[1:] != lhs[:-1])[0]
116        vtab = np.concatenate(([0], vtab + 1, [len(lhs)]))
117
118        # Compute the element type/index to vertex number map
119        vetimap = lhs[vtab[:-1]].tolist()
120        etivmap = {k: v for v, k in enumerate(vetimap)}
121
122        # Prepare the list of edges for each vertex
123        etab = np.array([etivmap[r] for r in rhs.tolist()])
124
125        # Prepare the list of vertex and edge weights
126        vwts = np.array([self.elewts[t] for t, i in vetimap])
127        ewts = np.ones_like(etab)
128
129        return Graph(vtab, etab, vwts, ewts), vetimap
130
    def _partition_graph(self, graph, partwts):
        """Assign each vertex of *graph* to a partition.

        This base implementation is a placeholder: it does nothing and
        returns None.  It is presumably meant to be overridden by
        concrete partitioner subclasses.

        The callers in this class pass a Graph tuple and a list of
        partition weights.  They iterate over the result alongside the
        vertex list, and ``partition`` additionally calls ``.tolist()``
        on it.
        """
        pass
133
134    def _renumber_verts(self, mesh, vetimap, vparts):
135        pscon = [[] for i in range(self.nparts)]
136        vpartmap, bndeti = dict(zip(vetimap, vparts)), set()
137
138        # Construct per-partition connectivity arrays and tag elements
139        # which are on partition boundaries
140        for l, r in zip(*mesh['con_p0'][['f0', 'f1']].tolist()):
141            if vpartmap[l] == vpartmap[r]:
142                pscon[vpartmap[l]].append([l, r])
143            else:
144                pscon[vpartmap[l]].append([l, r])
145                pscon[vpartmap[r]].append([l, r])
146                bndeti |= {l, r}
147
148        # Start by assigning the lowest numbers to these boundary elements
149        nvetimap, nvparts = list(bndeti), [vpartmap[eti] for eti in bndeti]
150
151        # Use sub-partitioning to assign interior element numbers
152        for part, scon in enumerate(pscon):
153            # Construct a graph for this partition
154            scon = np.array(scon, dtype='U4,i4').T
155            sgraph, svetimap = self._construct_graph(scon)
156
157            # Determine the number of sub-partitions
158            nsp = len(svetimap) // self.nsubeles + 1
159
160            # Partition the graph
161            if nsp == 1:
162                svparts = [0]*len(svetimap)
163            else:
164                svparts = self._partition_graph(sgraph, [1]*nsp)
165
166            # Group elements according to their type (linear vs curved)
167            # and sub-partition number
168            linsvetimap = [[] for i in range(nsp)]
169            cursvetimap = [[] for i in range(nsp)]
170            for (etype, eidx), spart in zip(svetimap, svparts):
171                if (etype, eidx) in bndeti:
172                    continue
173
174                if mesh[f'spt_{etype}_p0', 'linear'][eidx]:
175                    linsvetimap[spart].append((etype, eidx))
176                else:
177                    cursvetimap[spart].append((etype, eidx))
178
179            # Append to the global list
180            nvetimap.extend(it.chain(*cursvetimap, *linsvetimap))
181            nvparts.extend([part]*sum(map(len, cursvetimap + linsvetimap)))
182
183        return nvetimap, nvparts
184
185    def _partition_spts(self, mesh, vetimap, vparts):
186        spt_px = defaultdict(list)
187        lin_px = defaultdict(list)
188
189        for (etype, eidxg), part in zip(vetimap, vparts):
190            f = f'spt_{etype}_p0'
191
192            spt_px[etype, part].append(mesh[f][:, eidxg, :])
193            lin_px[etype, part].append(mesh[f, 'linear'][eidxg])
194
195        newmesh = {}
196        for etype, pn in spt_px:
197            f = f'spt_{etype}_p{pn}'
198
199            newmesh[f] = np.array(spt_px[etype, pn]).swapaxes(0, 1)
200            newmesh[f, 'linear'] = np.array(lin_px[etype, pn])
201
202        return newmesh
203
204    def _partition_soln(self, soln, prefix, vetimap, vparts):
205        soln_px = defaultdict(list)
206        for (etype, eidxg), part in zip(vetimap, vparts):
207            f = f'{prefix}_{etype}_p0'
208
209            soln_px[etype, part].append(soln[f][..., eidxg])
210
211        return {f'{prefix}_{etype}_p{pn}': np.dstack(v)
212                for (etype, pn), v in soln_px.items()}
213
214    def _partition_con(self, mesh, vetimap, vparts):
215        con_px = defaultdict(list)
216        con_pxpy = defaultdict(list)
217        bcon_px = defaultdict(list)
218
219        # Global-to-local element index map
220        eleglmap = {}
221        pcounter = Counter()
222
223        for (etype, eidxg), part in zip(vetimap, vparts):
224            eleglmap[etype, eidxg] = (part, pcounter[etype, part])
225            pcounter[etype, part] += 1
226
227        # Generate the face connectivity
228        for l, r in zip(*mesh['con_p0'].tolist()):
229            letype, leidxg, lfidx, lflags = l
230            retype, reidxg, rfidx, rflags = r
231
232            lpart, leidxl = eleglmap[letype, leidxg]
233            rpart, reidxl = eleglmap[retype, reidxg]
234
235            conl = (letype, leidxl, lfidx, lflags)
236            conr = (retype, reidxl, rfidx, rflags)
237
238            if lpart == rpart:
239                con_px[lpart].append([conl, conr])
240            else:
241                con_pxpy[lpart, rpart].append(conl)
242                con_pxpy[rpart, lpart].append(conr)
243
244        # Generate boundary conditions
245        for f in filter(lambda f: isinstance(f, str), mesh):
246            if (m := re.match('bcon_(.+?)_p0$', f)):
247                for lpetype, leidxg, lfidx, lflags in mesh[f].tolist():
248                    lpart, leidxl = eleglmap[lpetype, leidxg]
249                    conl = (lpetype, leidxl, lfidx, lflags)
250
251                    bcon_px[m.group(1), lpart].append(conl)
252
253        # Output data type
254        dtype = 'S4,i4,i1,i2'
255
256        # Output
257        con = {}
258
259        for px, v in con_px.items():
260            con[f'con_p{px}'] = np.array(v, dtype=dtype).T
261
262        for (px, py), v in con_pxpy.items():
263            con[f'con_p{px}p{py}'] = np.array(v, dtype=dtype)
264
265        for (etype, px), v in bcon_px.items():
266            con[f'bcon_{etype}_p{px}'] = np.array(v, dtype=dtype)
267
268        return con, eleglmap
269
270    def partition(self, mesh):
271        # Extract the current UUID from the mesh
272        curruuid = mesh['mesh_uuid']
273
274        # Combine any pre-existing partitions
275        mesh, rnum = self._combine_mesh_parts(mesh)
276
277        # Obtain the dual graph for this mesh
278        graph, vetimap = self._construct_graph(mesh['con_p0'])
279
280        # Partition the graph
281        if self.nparts > 1:
282            vparts = self._partition_graph(graph, self.partwts).tolist()
283
284            if (n := len(set(vparts))) != self.nparts:
285                raise RuntimeError(f'Partitioner error: mesh has {n} parts '
286                                   f'versus goal of {self.nparts}')
287        else:
288            vparts = [0]*len(vetimap)
289
290        # Renumber vertices
291        vetimap, vparts = self._renumber_verts(mesh, vetimap, vparts)
292
293        # Partition the connectivity portion of the mesh
294        newmesh, eleglmap = self._partition_con(mesh, vetimap, vparts)
295
296        # Handle the shape points
297        newmesh.update(self._partition_spts(mesh, vetimap, vparts))
298
299        # Update the renumbering table
300        for etype, emap in rnum.items():
301            for k, (pidx, eidx) in emap.items():
302                emap[k] = eleglmap[etype, eidx]
303
304        # Generate a new UUID for the mesh
305        newmesh['mesh_uuid'] = newuuid = str(uuid.uuid4())
306
307        # Build the solution converter
308        def partition_soln(soln):
309            # Check the UUID
310            if curruuid != soln['mesh_uuid']:
311                raise ValueError('Mismatched solution/mesh')
312
313            # Obtain the prefix
314            prefix = Inifile(soln['stats']).get('data', 'prefix')
315
316            # Combine and repartition the solution
317            newsoln = self._combine_soln_parts(soln, prefix)
318            newsoln = self._partition_soln(newsoln, prefix, vetimap, vparts)
319
320            # Copy over the metadata
321            for f in soln:
322                if re.match('stats|config|plugins', f):
323                    newsoln[f] = soln[f]
324
325            # Apply the new UUID
326            newsoln['mesh_uuid'] = newuuid
327
328            return newsoln
329
330        return newmesh, rnum, partition_soln
331