1import functools
2
3from guppy.etc.Descriptor import property_nondata
4
5
class R_NORELATION:
    """Relation kind with code -1 and no relation value.

    Its rendering appends the placeholder ``.??`` to the source expression.
    """

    code = -1
    r = None

    def stra(self, a, safe=True):
        """Return the expression *a* followed by ``.??``; *safe* is unused."""
        placeholder_pattern = '%s.??'
        return placeholder_pattern % a
12
13
class R_IDENTITY:
    """Relation kind with code 0: the rendering is the source expression."""

    code = 0

    def stra(self, a, safe=True):
        """Return *a* unchanged; *safe* is accepted but not used."""
        expression = a
        return expression
19
20
class R_ATTRIBUTE:
    """Relation kind with code 1, rendered in attribute-access form."""

    # Numeric code of this relation kind.
    code = 1
    # Two-slot %-pattern: first the source expression, then the relation value.
    strpat = '%s.%s'
24
25
class R_INDEXVAL:
    """Relation kind with code 2, rendered as a subscript ``a[r]``."""

    code = 2

    def stra(self, a, safe=True):
        """Return *a* subscripted by the relation value ``self.r``.

        With a true *safe* the value is rendered through ``self.saferepr``;
        otherwise the plain ``%r`` conversion is used.
        """
        if not safe:
            return '%s[%r]' % (a, self.r)
        return '%s[%s]' % (a, self.saferepr(self.r))
34
35
class R_INDEXKEY:
    """Relation kind with code 3, rendered as an index into ``keys()``."""

    # Numeric code of this relation kind.
    code = 3
    # Two-slot %-pattern: source expression, then the relation value via %r.
    strpat = '%s.keys()[%r]'
39
40
class R_INTERATTR:
    """Relation kind with code 4, rendered with an arrow (``a->r``)."""

    # Numeric code of this relation kind.
    code = 4
    # Two-slot %-pattern: first the source expression, then the relation value.
    strpat = '%s->%s'
44
45
class R_HASATTR:
    """Relation kind with code 5, rendered as an index into ``__dict__.keys()``."""

    # Numeric code of this relation kind.
    code = 5
    # Two-slot %-pattern: source expression, then the relation value via %r.
    strpat = '%s.__dict__.keys()[%r]'
49
50
class R_LOCAL_VAR:
    """Relation kind with code 6, rendered as a lookup in ``f_locals``."""

    # Numeric code of this relation kind.
    code = 6
    # Two-slot %-pattern: source expression, then the relation value via %r.
    strpat = '%s.f_locals[%r]'
54
55
class R_CELL:
    """Relation kind with code 7, rendered as a lookup in ``f_locals``.

    The pattern has a space before the bracket, which makes its output
    textually different from the one produced by ``R_LOCAL_VAR``.
    """

    # Numeric code of this relation kind.
    code = 7
    # Two-slot %-pattern: source expression, then the relation value via %r.
    strpat = '%s.f_locals [%r]'
59
60
class R_STACK:
    """Relation kind with code 8, rendered as an index into ``f_valuestack``."""

    # Numeric code of this relation kind.
    code = 8
    # Two-slot %-pattern: source expression, then an integer (%d) position.
    strpat = '%s->f_valuestack[%d]'
64
65
class R_INSET:
    """Relation kind with code 9, rendered as ``list(a)[n]``."""

    # Numeric code of this relation kind.
    code = 9
    # Two-slot %-pattern: source expression, then an integer (%d) position.
    strpat = 'list(%s)[%d]'
69
70
class R_RELSRC:
    """Relation kind with code 10 whose relation value is itself a pattern."""

    code = 10

    def stra(self, a, safe=True):
        """Format the pattern held in ``self.r`` with *a* as its only argument.

        *safe* is accepted but not used.
        """
        pattern = self.r
        return pattern % (a,)
76
77
class R_LIMIT:
    """Relation kind with code 11; it defines neither a pattern nor ``stra``."""

    # Numeric code of this relation kind.
    code = 11
80
81
@functools.total_ordering
class RelationBase(object):
    """Common behaviour of relation objects: storage, ordering, rendering.

    Subclasses are expected to supply a ``code`` attribute and either a
    two-slot ``strpat`` pattern or their own ``stra`` method.
    """

    __slots__ = 'r', 'isinverted'

    def __init__(self, r, isinverted=0):
        """Store the relation value *r* and the *isinverted* flag."""
        self.r = r
        self.isinverted = isinverted

    def __lt__(self, other):
        """Order by ``code`` first, then by relation value.

        Against an object that is not a RelationBase, the ids of the two
        types are compared instead.
        """
        if not isinstance(other, RelationBase):
            return id(type(self)) < id(type(other))
        if self.code != other.code:
            return self.code < other.code
        return self.r < other.r

    def __eq__(self, other):
        """Equal when *other* is a RelationBase with the same code and value."""
        if not isinstance(other, RelationBase):
            return False
        if self.code != other.code:
            return False
        return self.r == other.r

    def __str__(self):
        """Render with the literal ``'%s'`` standing in for the source."""
        return self.stra('%s')

    def inverted(self):
        """Return a new object of the same class with ``isinverted`` negated."""
        cls = self.__class__
        flipped = not self.isinverted
        return cls(self.r, flipped)

    def stra(self, a, safe=True):
        """Format ``strpat`` with *a* and the relation value; *safe* is unused."""
        fields = (a, self.r)
        return self.strpat % fields
114
115
class MultiRelation(RelationBase):
    """Several relations between the same pair of objects, shown as one.

    Only ``rels`` is stored: RelationBase.__init__ is not called, so the
    ``r`` and ``isinverted`` slots stay unset on instances.
    """

    def __init__(self, rels):
        """Keep the sequence of relation objects *rels*."""
        self.rels = rels

    def stra(self, a, safe=True):
        """Return the renderings of all relations, comma-joined inside ``<>``."""
        rendered = [rel.stra(a, safe=safe) for rel in self.rels]
        return '<' + ','.join(rendered) + '>'
122
123
@functools.total_ordering
class Path:
    """One path from a source object to a target object.

    The *path* argument is a sequence whose first element is skipped.  The
    stored ``self.path`` starts with the source and ends with the target;
    the elements at its odd positions are the relation objects whose
    ``stra`` methods are chained to render the path.
    """

    def __init__(self, mod, path, index, srcname):
        """Store the path and resolve the name shown for its source.

        mod     -- stored as ``self.mod``.
        path    -- sequence of at least two elements; element 0 is dropped.
        index   -- number printed in front of the path by _get_line_iter.
        srcname -- name of the source: the marker string '_str_of_src_'
                   selects ``self.src.brief``; a callable is called with
                   this Path and its result is used; anything else is
                   used as given.

        Raises IndexError if *path* has fewer than two elements.
        """
        self.mod = mod
        self.path = path[1:]
        self.index = index
        self.src = path[1]
        self.tgt = path[-1]
        self.strprefix = '%s'
        if srcname == '_str_of_src_':
            srcname = self.src.brief
        if callable(srcname):
            srcname = srcname(self)
        self.srcname = srcname

    def __lt__(self, other):
        """Order paths by their string form."""
        return str(self) < str(other)

    def __eq__(self, other):
        """Paths are equal when their string forms are equal."""
        return str(self) == str(other)

    def __len__(self):
        """Return half of (number of stored elements minus one), truncated."""
        return int((len(self.path) - 1) / 2)

    def _render(self, prefix, safe=True):
        """Chain the relations' ``stra`` methods, starting from *prefix*."""
        if not self.path:
            return '<Empty Path>'
        s = prefix
        # Relations sit at the odd positions of self.path.
        for rel in self.path[1::2]:
            s = rel.stra(s, safe=safe)
        return s

    def stra(self, safe=True):
        """Render the path with ``self.strprefix`` in place of the source."""
        return self._render(self.strprefix, safe=safe)

    def __str__(self):
        return self.stra()

    def __repr__(self):
        return repr(str(self))

    def _get_line_iter(self):
        """Yield the single output line ``'<index>: <path>'`` for this path.

        The source name is substituted into the prefix *before* the
        relations are rendered.  Formatting the finished string with
        ``%`` afterwards would misinterpret any '%' character that a
        relation contributes (for instance from a key containing '%').
        """
        prefix = self.strprefix % self.srcname
        yield '%2d: %s' % (self.index, self._render(prefix))

    def types(self):
        """Return the types of the stored path elements, in order."""
        return [type(x) for x in self.path]
169
170
class PathsIter:
    """Iterator over the paths of a paths object, from *start* up to *stop*.

    State kept between calls:
      srs  -- one list of (relation, node) pairs per path level; entries
              for levels above 0 may be None and are then recomputed.
      idxs -- for each level, the index of the pair currently selected.
      pos  -- number of the path that the next call will return.
      isatend -- true when there is nothing more to return.
    """

    def __init__(self, paths, start=None, stop=None):
        """Remember *paths* and *stop*, and position the iterator at *start*."""
        self.paths = paths
        self.mod = paths.mod
        self.stop = stop
        self.reset(start)

    def __iter__(self):
        return self

    def reset(self, idx=None):
        """Position the iterator so that the next path returned is number *idx*.

        None means 0.  A negative *idx* counts from the end.  An index
        outside the valid range only marks the iterator as being at its end.
        """
        if idx is None:
            idx = 0
        if idx != 0:  # Optimization: don't calculate numpaths in common case.
            ln = self.paths.numpaths
            if idx < 0:
                idx = ln + idx
            if not (0 <= idx < ln):
                # Out of range: nothing else is initialized; __next__ checks
                # isatend before it touches any other state.
                self.isatend = 1
                return
        Src = self.paths.Src
        # Level 0: one entry per source part, with the '%s' placeholder in
        # the position where deeper levels hold a relation.
        sr = [('%s', src.by(Src.er)) for src in Src.byid.parts]
        srs = []
        idxs = []
        # Running total of the paths counted for the entries skipped so far.
        np = 0
        while sr:
            if idx == 0:
                # First path: always the first entry of each level.
                i, (rel, src) = 0, sr[0]
            else:
                # Pick the first entry whose cumulative path count exceeds
                # idx; the counts of the skipped entries are added to np.
                for i, (rel, src) in enumerate(sr):
                    npnext = np + self.paths.numpaths_from(src)
                    if idx < npnext:
                        break
                    np = npnext
                else:
                    # idx was range-checked above, so some entry must match.
                    assert 0
            idxs.append(i)
            srs.append(sr)
            # Next level: the sorted relations leading on from the chosen node.
            sr = self.mod.sortedrels(self.paths.IG, src)
        self.pos = idx
        self.idxs = idxs
        self.srs = srs
        self.isatend = not idxs

    def __next__(self):
        """Return the path at the current position and advance by one.

        Raises StopIteration at the end, or once ``pos`` has reached ``stop``.
        """
        paths = self.paths
        if (self.isatend or
                self.stop is not None and self.pos >= self.stop):
            raise StopIteration
        # Build the flat list [rel, node, rel, node, ...] for the current
        # index vector.
        path = []
        for row, col in enumerate(self.idxs):
            sr = self.srs[row]
            if sr is None:
                # This level was invalidated by an earlier advance; recompute
                # it from the node just appended for the previous level.
                sr = self.mod.sortedrels(paths.IG, path[-1])
                self.srs[row] = sr
            rel, dst = sr[col]
            path.append(rel)
            path.append(dst)
        rp = self.mod.Path(paths, path, self.pos, paths.srcname)
        self.pos += 1
        # Advance the index vector like an odometer, starting at the deepest
        # level (row still holds the last index of the loop above).
        while row >= 0:
            self.idxs[row] += 1
            if self.idxs[row] < len(self.srs[row]):
                break
            # This level is exhausted: wrap to 0 and carry into the level
            # above.  Cached pairs are dropped for every level except 0.
            if row > 0:
                self.srs[row] = None
            self.idxs[row] = 0
            row -= 1
        else:
            # The carry went past level 0: every path has been produced.
            self.isatend = 1
            self.pos = 0
        return rp
243
244
class ShortestPaths:
    """The paths leading to one destination set of a ShortestGraph.

    Supports indexing and iteration (both yield path objects produced by
    PathsIter) and provides the ``_oh_*`` hooks used for printing.
    """

    def __init__(self, sg, Dst):
        """Build the inverted graph and the per-level edge sets for *Dst*.

        sg  -- the graph object; its ``mod``, ``G``, ``srcname`` and
               ``AvoidEdges`` attributes are used by this class.
        Dst -- the destination set; its ``nodes`` attribute is used.
        """
        self.sg = sg
        self.Dst = Dst
        mod = sg.mod
        self.mod = mod
        self._hiding_tag_ = mod._hiding_tag_
        self.srcname = sg.srcname
        self.top = self

        inverted = mod.nodegraph()
        self.IG = inverted
        levels = []
        frontier = Dst.nodes
        # Walk from the destination: restrict G to the current frontier,
        # invert those edges, collect them, and continue from their domain.
        while frontier:
            step = sg.G.domain_restricted(frontier)
            step.invert()
            inverted.update(step)
            levels.append(step)
            frontier = step.get_domain()
        if levels:
            # The step collected last is discarded; the remaining ones are
            # put in the opposite order.
            levels.pop()
            levels.reverse()
            self.Src = mod.idset(levels[0].get_domain())
        else:
            self.Src = mod.iso()

        self.edges = tuple(levels)
        # Node sets per level: the domain of the first edge set, followed by
        # the range of every edge set.
        sets = []
        if levels:
            sets.append(mod.idset(levels[0].get_domain()))
        sets.extend(mod.idset(level.get_range()) for level in levels)
        self.sets = tuple(sets)

        mod.OutputHandling.setup_printing(self)

        self.maxpaths = 10

    def __getitem__(self, idx):
        """Return path number *idx*; raise IndexError when there is none."""
        try:
            paths_iter = self.iter(start=idx)
            return next(paths_iter)
        except StopIteration:
            raise IndexError

    def __iter__(self):
        return self.iter()

    def iter(self, start=0, stop=None):
        """Return a PathsIter over this object from *start* up to *stop*."""
        return PathsIter(self, start=start, stop=stop)

    def aslist(self):
        """Return all paths as a list."""
        return list(self.iter())

    def copy_but_avoid_edges_at_levels(self, *args):
        """Return new shortest paths that avoid the edges at the given levels.

        The edges selected by ``edges_at(*args)`` are combined with the
        edges the graph already avoids and passed to ``mod.shpaths``.
        """
        level_edges = self.edges_at(*args)
        avoid = level_edges.updated(self.sg.AvoidEdges)
        assert avoid._hiding_tag_ is self.mod._hiding_tag_
        return self.mod.shpaths(self.Dst, self.Src, avoid_edges=avoid)

    avoided = copy_but_avoid_edges_at_levels

    # No __len__ is defined on purpose: Python restricts the result of the
    # builtin len() to int, which does not always work for the number of
    # paths and would cause unexpected errors.  Use the .numpaths attribute
    # instead.

    def depth(self):
        pass

    def edges_at(self, *args):
        """Return a node graph holding the edges of the levels in *args*."""
        collected = self.mod.nodegraph()
        for level in args:
            collected.update(self.edges[level])
        assert collected._hiding_tag_ == self.mod._parent.View._hiding_tag_
        return collected

    def numpaths_from(self, Src):
        """Return the number of paths that start at the nodes of *Src*.

        Per-node counts are memoized in ``self.NP``, which is created on
        first use with every destination node mapped to 1.
        """
        NP = getattr(self, 'NP', None)
        if NP is None:
            NP = self.mod.nodegraph(is_mapping=True)
            NP.add_edges_n1(self.IG.get_domain(), None)
            for dst in self.Dst.nodes:
                NP.add_edge(dst, 1)
            self.NP = NP
        numedges = self.mod.hv.numedges
        graph = self.IG

        def count_from(node):
            # A stored None means "not computed yet".
            known = NP[node]
            if known is not None:
                return known
            total = 0
            for succ in graph[node]:
                below = NP[succ]
                if below is None:
                    below = count_from(succ)
                # Every edge between the two nodes counts as its own path.
                total += below * numedges(node, succ)
            NP[node] = total
            return total

        return sum(count_from(src) for src in Src.nodes)

    def _get_numpaths(self):
        """Compute the total number of paths and store it on the instance."""
        total = self.numpaths_from(self.Src)
        self.numpaths = total
        return total

    numpaths = property_nondata(fget=_get_numpaths)

    @property
    def maxpaths(self):
        """The ``max_more_lines`` setting of this object's printer."""
        return self.printer.max_more_lines

    @maxpaths.setter
    def maxpaths(self, value):
        self.printer.max_more_lines = value

    def _oh_get_num_lines(self):
        return self.numpaths

    def _oh_get_line_iter(self):
        """Yield the output lines of every path, in order."""
        for path in self:
            yield from path._get_line_iter()

    def _oh_get_more_msg(self, start_lineno, end_lineno):
        """Return the message saying how many paths follow *end_lineno*."""
        remaining = self.numpaths - (end_lineno + 1)
        return '<... %d more paths ...>' % remaining

    def _oh_get_empty_msg(self):
        """Return '<No more paths>' if any paths exist, otherwise None."""
        return '<No more paths>' if self.numpaths else None
379
380
class ShortestGraph:
    """Result of a shortest-path search towards several destination sets.

    Indexing with *i* returns a ``mod.ShortestPaths`` object built for the
    i-th destination set; ``len()`` is the number of destination sets.
    """

    def __init__(self, mod, G, DstSets, Src, AvoidEdges,
                 srcname=None, dstname=None):
        """Store the arguments and fill in default names.

        When *srcname* is None it is taken from ``mod.srcname_1`` if
        ``Src.count`` equals 1 and from ``mod.srcname_n`` otherwise.
        When *dstname* is None, ``mod.dstname`` is used.
        """
        self.mod = mod
        self.G = G
        self.Src = Src
        self.DstSets = DstSets
        self.AvoidEdges = AvoidEdges
        if srcname is None:
            srcname = mod.srcname_1 if Src.count == 1 else mod.srcname_n
        self.srcname = srcname
        self.dstname = mod.dstname if dstname is None else dstname

    def __getitem__(self, idx):
        """Return the shortest paths to destination set number *idx*."""
        dst = self.DstSets[idx]
        return self.mod.ShortestPaths(self, dst)

    def __len__(self):
        return len(self.DstSets)

    def __repr__(self):
        """Return a header line plus the string form for each destination."""
        lines = []
        for i, paths in enumerate(self):
            lines += ['--- %s[%d] ---' % (self.dstname, i), str(paths)]
        return '\n'.join(lines)
412
413
class _GLUECLAMP_:
    # Glue description of this module: configuration tuples, default names,
    # attribute getters (_get_*) and the shortest-path entry points.

    _preload_ = ('_hiding_tag_',)
    _chgable_ = ('output', 'srcname_1', 'srcname_n')

    srcname_1 = 'Src'
    srcname_n = '_str_of_src_'
    dstname = 'Dst'

    _imports_ = (
        '_parent.ImpSet:mutnodeset',
        '_parent:OutputHandling',
        '_parent.Use:idset',
        '_parent.Use:iso',
        '_parent.Use:Nothing',
        '_parent.Use:reprefix',
        '_parent.UniSet:idset_adapt',
        '_parent.View:hv',
        '_parent.View:nodegraph',
        '_parent:View',  # NOT View.root, since it may change
    )

    def _get_rel_table(self):
        """Return a dict mapping relation codes to 'Based_R_*' classes.

        For every name in the module starting with 'R_', a class is created
        that derives from that R_ class and from RelationBase, with both
        ``repr`` and ``saferepr`` set to ``self.saferepr``.
        """
        table = {}
        for name in dir(self._module):
            if not name.startswith('R_'):
                continue
            kind = getattr(self, name)

            class based(kind, self.RelationBase):
                repr = self.saferepr
                saferepr = self.saferepr
            based.__qualname__ = based.__name__ = 'Based_'+name
            table[kind.code] = based
        return table

    def _get__hiding_tag_(self):
        return self._parent.View._hiding_tag_

    def _get_identity(self):
        return self.rel_table[R_IDENTITY.code]('')

    def _get_norelation(self):
        return self.rel_table[R_NORELATION.code]('')

    def _get_saferepr(self):
        return self._root.reprlib.repr

    def _get_shpathstep(self):
        return self.hv.shpathstep

    def sortedrels(self, IG, Src):
        """Return (relation, destination set) pairs reachable from *Src*.

        Each successor ``dst`` of each source node in *IG* is wrapped with
        ``self.iso`` and paired with every relation between the two; the
        result is sorted by relation.
        """
        iso = self.iso
        pairs = []
        for src in Src.nodes:
            for dst in IG[src]:
                wrapped = iso(dst)
                pairs.extend(
                    (rel, wrapped) for rel in self.relations(src, dst))
        return sorted(pairs, key=lambda pair: pair[0])

    def prunedinverted(self, G, Y):
        """Return the inverted edges of *G* found by walking back from *Y*."""
        inverted = self.nodegraph()
        frontier = Y
        while frontier:
            step = G.domain_restricted(frontier)
            step.invert()
            inverted.update(step)
            frontier = step.get_domain()
        return inverted

    def relation(self, src, dst):
        """Return one relation object describing how *src* refers to *dst*.

        Several relations are combined into a MultiRelation; an empty
        result from ``relations`` gives ``self.norelation``.
        """
        found = self.relations(src, dst)
        if len(found) > 1:
            return MultiRelation(found)
        if not found:
            return self.norelation
        return found[0]

    def relations(self, src, dst):
        """Return the list of relation objects from *src* to *dst*.

        The identity relation comes first when both are the same object.
        The groups returned by ``hv.relate`` are converted with the class
        that ``rel_table`` holds for the group's position.  If nothing is
        found, the list holds only ``self.norelation``.
        """
        found = [self.identity] if src is dst else []
        for code, group in enumerate(self.hv.relate(src, dst)):
            found.extend(self.rel_table[code](raw) for raw in group)
        return found or [self.norelation]

    def shpaths(self, dst, src=None, avoid_nodes=None, avoid_edges=()):
        """Return the shortest paths to the single destination *dst*."""
        graph = self.shpgraph([dst], src, avoid_nodes, avoid_edges)
        return graph[0]

    def shpgraph(self, DstSets, src=None, avoid_nodes=None, avoid_edges=(),
                 srcname=None, dstname=None):
        """Return a ShortestGraph for the destination sets *DstSets*.

        With *src* None the search starts at ``View.root``; in that case,
        if no *srcname* was given and the root is ``View.heapyc.RootState``,
        the source name becomes ``'%sRoot' % self.reprefix``.
        """
        if src is not None:
            Src = self.idset_adapt(src)
        else:
            Src = self.iso(self.View.root)
            if srcname is None and self.View.root is self.View.heapyc.RootState:
                srcname = '%sRoot' % self.reprefix
        AvoidNodes = (self.Nothing if avoid_nodes is None
                      else self.idset_adapt(avoid_nodes))
        AvoidEdges = self.nodegraph(avoid_edges)
        G, DstSets = self.shpgraph_algorithm(
            DstSets, Src, AvoidNodes, AvoidEdges)

        return self.ShortestGraph(self, G, DstSets, Src, AvoidEdges,
                                  srcname, dstname)

    def shpgraph_algorithm(self, DstSets, Src, AvoidNodes, AvoidEdges):
        """Step outward from *Src* until every destination set is reached.

        Returns ``(G, found)``: *G* is the graph filled in by
        ``shpathstep`` and *found* holds, per destination set, the part of
        it that was reached first (``self.Nothing`` if never reached),
        adapted with ``idset_adapt``.
        """
        frontier = (Src - AvoidNodes).nodes
        visited = self.mutnodeset(AvoidNodes.nodes)
        G = self.nodegraph()
        pending = list(enumerate(DstSets))
        found = [self.Nothing]*len(DstSets)
        while frontier and pending:
            visited |= frontier
            frontier = self.shpathstep(G, frontier, visited, AvoidEdges)
            still_pending = []
            for i, wanted in pending:
                hit = wanted & frontier
                if hit:
                    found[i] = hit
                else:
                    still_pending.append((i, wanted))
            pending = still_pending
        return G, [self.idset_adapt(D) for D in found]
536
537
class _Specification_:
    class GlueTypeExpr:
        # The text below is a small table: a name, the marker '<in>', and an
        # expression.  Replacing '<in>' with ' = lambda IN : ' turns the row
        # into the statement `shpgraph = lambda IN : callable`, which is then
        # executed.
        # NOTE(review): presumably exec binds the name in this class
        # namespace, making it an attribute of GlueTypeExpr -- confirm.
        exec("""\
shpgraph        <in>    callable
""".replace('<in>', ' = lambda IN : '))
543