1"""
2****************
3ISMAGS Algorithm
4****************
5
6Provides a Python implementation of the ISMAGS algorithm. [1]_
7
8It is capable of finding (subgraph) isomorphisms between two graphs, taking the
9symmetry of the subgraph into account. In most cases the VF2 algorithm is
10faster (at least on small graphs) than this implementation, but in some cases
11there is an exponential number of isomorphisms that are symmetrically
12equivalent. In that case, the ISMAGS algorithm will provide only one solution
13per symmetry group.
14
15>>> petersen = nx.petersen_graph()
16>>> ismags = nx.isomorphism.ISMAGS(petersen, petersen)
17>>> isomorphisms = list(ismags.isomorphisms_iter(symmetry=False))
18>>> len(isomorphisms)
19120
20>>> isomorphisms = list(ismags.isomorphisms_iter(symmetry=True))
21>>> answer = [{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}]
22>>> answer == isomorphisms
23True
24
25In addition, this implementation also provides an interface to find the
26largest common induced subgraph [2]_ between any two graphs, again taking
27symmetry into account. Given `graph` and `subgraph` the algorithm will remove
28nodes from the `subgraph` until `subgraph` is isomorphic to a subgraph of
29`graph`. Since only the symmetry of `subgraph` is taken into account it is
30worth thinking about how you provide your graphs:
31
32>>> graph1 = nx.path_graph(4)
33>>> graph2 = nx.star_graph(3)
34>>> ismags = nx.isomorphism.ISMAGS(graph1, graph2)
35>>> ismags.is_isomorphic()
36False
37>>> largest_common_subgraph = list(ismags.largest_common_subgraph())
38>>> answer = [{1: 0, 0: 1, 2: 2}, {2: 0, 1: 1, 3: 2}]
39>>> answer == largest_common_subgraph
40True
41>>> ismags2 = nx.isomorphism.ISMAGS(graph2, graph1)
42>>> largest_common_subgraph = list(ismags2.largest_common_subgraph())
43>>> answer = [
44...     {1: 0, 0: 1, 2: 2},
45...     {1: 0, 0: 1, 3: 2},
46...     {2: 0, 0: 1, 1: 2},
47...     {2: 0, 0: 1, 3: 2},
48...     {3: 0, 0: 1, 1: 2},
49...     {3: 0, 0: 1, 2: 2},
50... ]
51>>> answer == largest_common_subgraph
52True
53
54However, when not taking symmetry into account, it doesn't matter:
55
56>>> largest_common_subgraph = list(ismags.largest_common_subgraph(symmetry=False))
57>>> answer = [
58...     {1: 0, 0: 1, 2: 2},
59...     {1: 0, 2: 1, 0: 2},
60...     {2: 0, 1: 1, 3: 2},
61...     {2: 0, 3: 1, 1: 2},
62...     {1: 0, 0: 1, 2: 3},
63...     {1: 0, 2: 1, 0: 3},
64...     {2: 0, 1: 1, 3: 3},
65...     {2: 0, 3: 1, 1: 3},
66...     {1: 0, 0: 2, 2: 3},
67...     {1: 0, 2: 2, 0: 3},
68...     {2: 0, 1: 2, 3: 3},
69...     {2: 0, 3: 2, 1: 3},
70... ]
71>>> answer == largest_common_subgraph
72True
73>>> largest_common_subgraph = list(ismags2.largest_common_subgraph(symmetry=False))
74>>> answer = [
75...     {1: 0, 0: 1, 2: 2},
76...     {1: 0, 0: 1, 3: 2},
77...     {2: 0, 0: 1, 1: 2},
78...     {2: 0, 0: 1, 3: 2},
79...     {3: 0, 0: 1, 1: 2},
80...     {3: 0, 0: 1, 2: 2},
81...     {1: 1, 0: 2, 2: 3},
82...     {1: 1, 0: 2, 3: 3},
83...     {2: 1, 0: 2, 1: 3},
84...     {2: 1, 0: 2, 3: 3},
85...     {3: 1, 0: 2, 1: 3},
86...     {3: 1, 0: 2, 2: 3},
87... ]
88>>> answer == largest_common_subgraph
89True
90
91Notes
92-----
93 - The current implementation works for undirected graphs only. The algorithm
94   in general should work for directed graphs as well though.
95 - Node keys for both provided graphs need to be fully orderable as well as
96   hashable.
97 - Node and edge equality is assumed to be transitive: if A is equal to B, and
98   B is equal to C, then A is equal to C.
99
100References
101----------
102    .. [1] M. Houbraken, S. Demeyer, T. Michoel, P. Audenaert, D. Colle,
103       M. Pickavet, "The Index-Based Subgraph Matching Algorithm with General
104       Symmetries (ISMAGS): Exploiting Symmetry for Faster Subgraph
105       Enumeration", PLoS One 9(5): e97896, 2014.
106       https://doi.org/10.1371/journal.pone.0097896
107    .. [2] https://en.wikipedia.org/wiki/Maximum_common_induced_subgraph
108"""
109
110__all__ = ["ISMAGS"]
111
112from collections import defaultdict, Counter
113from functools import reduce, wraps
114import itertools
115
116
def are_all_equal(iterable):
    """
    Check whether every element of `iterable` compares equal to the others.

    Empty and single-element iterables count as "all equal".

    Parameters
    ----------
    iterable: collections.abc.Iterable
        The elements to compare. Objects exposing a ``shape`` attribute that
        describes more than one dimension are rejected.

    Returns
    -------
    bool
        ``True`` if every element compares equal to the first one,
        ``False`` otherwise.

    Raises
    ------
    NotImplementedError
        If `iterable` has a ``shape`` attribute of length greater than 1.
    """
    # A private sentinel distinguishes "no shape attribute" from any value
    # the attribute could legitimately hold.
    no_shape = object()
    dimensions = getattr(iterable, "shape", no_shape)
    if dimensions is not no_shape and len(dimensions) > 1:
        message = "The function does not works on multidimension arrays."
        raise NotImplementedError(message) from None

    remaining = iter(iterable)
    # The first element is the reference every later element is held against.
    reference = next(remaining, None)
    for element in remaining:
        if not element == reference:
            return False
    return True
145
146
def make_partitions(items, test):
    """
    Group `items` into sets of mutually equivalent items.

    Each item is compared, via ``test(item, representative)``, against one
    representative of every group built so far. It joins the first group
    whose representative it matches, or starts a new group if none matches.

    Parameters
    ----------
    items : collections.abc.Iterable[collections.abc.Hashable]
        Items to partition.
    test : collections.abc.Callable[collections.abc.Hashable, collections.abc.Hashable]
        Called with two items; must return `True` when they belong in the
        same partition and `False` otherwise.

    Returns
    -------
    list[set]
        The partitions, in order of creation. Every item of `items` is in
        exactly one of the sets.

    Notes
    -----
    `test` is assumed to be transitive: if ``test(a, b)`` and ``test(b, c)``
    are ``True`` then ``test(a, c)`` must be ``True`` as well. Only one
    representative per group is ever consulted.
    """
    groups = []
    for candidate in items:
        # Lazily scan the groups; stop at the first one that accepts.
        home = next(
            (group for group in groups if test(candidate, next(iter(group)))),
            None,
        )
        if home is None:
            groups.append({candidate})
        else:
            home.add(candidate)
    return groups
183
184
def partition_to_color(partitions):
    """
    Map every item to the index of the partition that contains it.

    Parameters
    ----------
    partitions: collections.abc.Sequence[collections.abc.Iterable]
        As returned by :func:`make_partitions`.

    Returns
    -------
    dict
        ``{item: index}`` for every item of every partition, where `index`
        is the position of that partition in `partitions`. If an item occurs
        in several partitions the last one wins.
    """
    return {
        member: index
        for index, members in enumerate(partitions)
        for member in members
    }
204
205
def intersect(collection_of_sets):
    """
    Return the intersection of all sets in a collection.

    Parameters
    ----------
    collection_of_sets: collections.abc.Collection[set]
        A non-empty collection of sets.

    Returns
    -------
    set
        The elements common to every set in `collection_of_sets`. The result
        has the same type as the last item of the collection (for example
        ``frozenset`` if that item is a frozenset).

    Raises
    ------
    IndexError
        If `collection_of_sets` is empty.
    """
    pool = list(collection_of_sets)
    # The last item seeds the intersection and decides the return type.
    template = pool.pop()
    shared = set(template)
    for other in pool:
        shared = shared.intersection(other)
    return type(template)(shared)
225
226
227class ISMAGS:
228    """
    Implements the ISMAGS subgraph matching algorithm. [1]_ ISMAGS stands for
230    "Index-based Subgraph Matching Algorithm with General Symmetries". As the
231    name implies, it is symmetry aware and will only generate non-symmetric
232    isomorphisms.
233
234    Notes
235    -----
236    The implementation imposes additional conditions compared to the VF2
237    algorithm on the graphs provided and the comparison functions
238    (:attr:`node_equality` and :attr:`edge_equality`):
239
240     - Node keys in both graphs must be orderable as well as hashable.
241     - Equality must be transitive: if A is equal to B, and B is equal to C,
242       then A must be equal to C.
243
244    Attributes
245    ----------
246    graph: networkx.Graph
247    subgraph: networkx.Graph
248    node_equality: collections.abc.Callable
249        The function called to see if two nodes should be considered equal.
        Its signature looks like this:
251        ``f(graph1: networkx.Graph, node1, graph2: networkx.Graph, node2) -> bool``.
252        `node1` is a node in `graph1`, and `node2` a node in `graph2`.
253        Constructed from the argument `node_match`.
254    edge_equality: collections.abc.Callable
255        The function called to see if two edges should be considered equal.
        Its signature looks like this:
257        ``f(graph1: networkx.Graph, edge1, graph2: networkx.Graph, edge2) -> bool``.
258        `edge1` is an edge in `graph1`, and `edge2` an edge in `graph2`.
259        Constructed from the argument `edge_match`.
260
261    References
262    ----------
263    .. [1] M. Houbraken, S. Demeyer, T. Michoel, P. Audenaert, D. Colle,
264       M. Pickavet, "The Index-Based Subgraph Matching Algorithm with General
265       Symmetries (ISMAGS): Exploiting Symmetry for Faster Subgraph
266       Enumeration", PLoS One 9(5): e97896, 2014.
267       https://doi.org/10.1371/journal.pone.0097896
268    """
269
270    def __init__(self, graph, subgraph, node_match=None, edge_match=None, cache=None):
271        """
272        Parameters
273        ----------
274        graph: networkx.Graph
275        subgraph: networkx.Graph
276        node_match: collections.abc.Callable or None
277            Function used to determine whether two nodes are equivalent. Its
278            signature should look like ``f(n1: dict, n2: dict) -> bool``, with
279            `n1` and `n2` node property dicts. See also
280            :func:`~networkx.algorithms.isomorphism.categorical_node_match` and
281            friends.
282            If `None`, all nodes are considered equal.
283        edge_match: collections.abc.Callable or None
284            Function used to determine whether two edges are equivalent. Its
285            signature should look like ``f(e1: dict, e2: dict) -> bool``, with
286            `e1` and `e2` edge property dicts. See also
287            :func:`~networkx.algorithms.isomorphism.categorical_edge_match` and
288            friends.
289            If `None`, all edges are considered equal.
290        cache: collections.abc.Mapping
291            A cache used for caching graph symmetries.
292        """
293        # TODO: graph and subgraph setter methods that invalidate the caches.
294        # TODO: allow for precomputed partitions and colors
295        self.graph = graph
296        self.subgraph = subgraph
297        self._symmetry_cache = cache
298        # Naming conventions are taken from the original paper. For your
299        # sanity:
300        #   sg: subgraph
301        #   g: graph
302        #   e: edge(s)
303        #   n: node(s)
304        # So: sgn means "subgraph nodes".
305        self._sgn_partitions_ = None
306        self._sge_partitions_ = None
307
308        self._sgn_colors_ = None
309        self._sge_colors_ = None
310
311        self._gn_partitions_ = None
312        self._ge_partitions_ = None
313
314        self._gn_colors_ = None
315        self._ge_colors_ = None
316
317        self._node_compat_ = None
318        self._edge_compat_ = None
319
320        if node_match is None:
321            self.node_equality = self._node_match_maker(lambda n1, n2: True)
322            self._sgn_partitions_ = [set(self.subgraph.nodes)]
323            self._gn_partitions_ = [set(self.graph.nodes)]
324            self._node_compat_ = {0: 0}
325        else:
326            self.node_equality = self._node_match_maker(node_match)
327        if edge_match is None:
328            self.edge_equality = self._edge_match_maker(lambda e1, e2: True)
329            self._sge_partitions_ = [set(self.subgraph.edges)]
330            self._ge_partitions_ = [set(self.graph.edges)]
331            self._edge_compat_ = {0: 0}
332        else:
333            self.edge_equality = self._edge_match_maker(edge_match)
334
335    @property
336    def _sgn_partitions(self):
337        if self._sgn_partitions_ is None:
338
339            def nodematch(node1, node2):
340                return self.node_equality(self.subgraph, node1, self.subgraph, node2)
341
342            self._sgn_partitions_ = make_partitions(self.subgraph.nodes, nodematch)
343        return self._sgn_partitions_
344
345    @property
346    def _sge_partitions(self):
347        if self._sge_partitions_ is None:
348
349            def edgematch(edge1, edge2):
350                return self.edge_equality(self.subgraph, edge1, self.subgraph, edge2)
351
352            self._sge_partitions_ = make_partitions(self.subgraph.edges, edgematch)
353        return self._sge_partitions_
354
355    @property
356    def _gn_partitions(self):
357        if self._gn_partitions_ is None:
358
359            def nodematch(node1, node2):
360                return self.node_equality(self.graph, node1, self.graph, node2)
361
362            self._gn_partitions_ = make_partitions(self.graph.nodes, nodematch)
363        return self._gn_partitions_
364
365    @property
366    def _ge_partitions(self):
367        if self._ge_partitions_ is None:
368
369            def edgematch(edge1, edge2):
370                return self.edge_equality(self.graph, edge1, self.graph, edge2)
371
372            self._ge_partitions_ = make_partitions(self.graph.edges, edgematch)
373        return self._ge_partitions_
374
375    @property
376    def _sgn_colors(self):
377        if self._sgn_colors_ is None:
378            self._sgn_colors_ = partition_to_color(self._sgn_partitions)
379        return self._sgn_colors_
380
381    @property
382    def _sge_colors(self):
383        if self._sge_colors_ is None:
384            self._sge_colors_ = partition_to_color(self._sge_partitions)
385        return self._sge_colors_
386
387    @property
388    def _gn_colors(self):
389        if self._gn_colors_ is None:
390            self._gn_colors_ = partition_to_color(self._gn_partitions)
391        return self._gn_colors_
392
393    @property
394    def _ge_colors(self):
395        if self._ge_colors_ is None:
396            self._ge_colors_ = partition_to_color(self._ge_partitions)
397        return self._ge_colors_
398
399    @property
400    def _node_compatibility(self):
401        if self._node_compat_ is not None:
402            return self._node_compat_
403        self._node_compat_ = {}
404        for sgn_part_color, gn_part_color in itertools.product(
405            range(len(self._sgn_partitions)), range(len(self._gn_partitions))
406        ):
407            sgn = next(iter(self._sgn_partitions[sgn_part_color]))
408            gn = next(iter(self._gn_partitions[gn_part_color]))
409            if self.node_equality(self.subgraph, sgn, self.graph, gn):
410                self._node_compat_[sgn_part_color] = gn_part_color
411        return self._node_compat_
412
413    @property
414    def _edge_compatibility(self):
415        if self._edge_compat_ is not None:
416            return self._edge_compat_
417        self._edge_compat_ = {}
418        for sge_part_color, ge_part_color in itertools.product(
419            range(len(self._sge_partitions)), range(len(self._ge_partitions))
420        ):
421            sge = next(iter(self._sge_partitions[sge_part_color]))
422            ge = next(iter(self._ge_partitions[ge_part_color]))
423            if self.edge_equality(self.subgraph, sge, self.graph, ge):
424                self._edge_compat_[sge_part_color] = ge_part_color
425        return self._edge_compat_
426
427    @staticmethod
428    def _node_match_maker(cmp):
429        @wraps(cmp)
430        def comparer(graph1, node1, graph2, node2):
431            return cmp(graph1.nodes[node1], graph2.nodes[node2])
432
433        return comparer
434
435    @staticmethod
436    def _edge_match_maker(cmp):
437        @wraps(cmp)
438        def comparer(graph1, edge1, graph2, edge2):
439            return cmp(graph1.edges[edge1], graph2.edges[edge2])
440
441        return comparer
442
443    def find_isomorphisms(self, symmetry=True):
444        """Find all subgraph isomorphisms between subgraph and graph
445
446        Finds isomorphisms where :attr:`subgraph` <= :attr:`graph`.
447
448        Parameters
449        ----------
450        symmetry: bool
451            Whether symmetry should be taken into account. If False, found
452            isomorphisms may be symmetrically equivalent.
453
454        Yields
455        ------
456        dict
457            The found isomorphism mappings of {graph_node: subgraph_node}.
458        """
459        # The networkx VF2 algorithm is slightly funny in when it yields an
460        # empty dict and when not.
461        if not self.subgraph:
462            yield {}
463            return
464        elif not self.graph:
465            return
466        elif len(self.graph) < len(self.subgraph):
467            return
468
469        if symmetry:
470            _, cosets = self.analyze_symmetry(
471                self.subgraph, self._sgn_partitions, self._sge_colors
472            )
473            constraints = self._make_constraints(cosets)
474        else:
475            constraints = []
476
477        candidates = self._find_nodecolor_candidates()
478        la_candidates = self._get_lookahead_candidates()
479        for sgn in self.subgraph:
480            extra_candidates = la_candidates[sgn]
481            if extra_candidates:
482                candidates[sgn] = candidates[sgn] | {frozenset(extra_candidates)}
483
484        if any(candidates.values()):
485            start_sgn = min(candidates, key=lambda n: min(candidates[n], key=len))
486            candidates[start_sgn] = (intersect(candidates[start_sgn]),)
487            yield from self._map_nodes(start_sgn, candidates, constraints)
488        else:
489            return
490
491    @staticmethod
492    def _find_neighbor_color_count(graph, node, node_color, edge_color):
493        """
494        For `node` in `graph`, count the number of edges of a specific color
495        it has to nodes of a specific color.
496        """
497        counts = Counter()
498        neighbors = graph[node]
499        for neighbor in neighbors:
500            n_color = node_color[neighbor]
501            if (node, neighbor) in edge_color:
502                e_color = edge_color[node, neighbor]
503            else:
504                e_color = edge_color[neighbor, node]
505            counts[e_color, n_color] += 1
506        return counts
507
508    def _get_lookahead_candidates(self):
509        """
510        Returns a mapping of {subgraph node: collection of graph nodes} for
511        which the graph nodes are feasible candidates for the subgraph node, as
512        determined by looking ahead one edge.
513        """
514        g_counts = {}
515        for gn in self.graph:
516            g_counts[gn] = self._find_neighbor_color_count(
517                self.graph, gn, self._gn_colors, self._ge_colors
518            )
519        candidates = defaultdict(set)
520        for sgn in self.subgraph:
521            sg_count = self._find_neighbor_color_count(
522                self.subgraph, sgn, self._sgn_colors, self._sge_colors
523            )
524            new_sg_count = Counter()
525            for (sge_color, sgn_color), count in sg_count.items():
526                try:
527                    ge_color = self._edge_compatibility[sge_color]
528                    gn_color = self._node_compatibility[sgn_color]
529                except KeyError:
530                    pass
531                else:
532                    new_sg_count[ge_color, gn_color] = count
533
534            for gn, g_count in g_counts.items():
535                if all(new_sg_count[x] <= g_count[x] for x in new_sg_count):
536                    # Valid candidate
537                    candidates[sgn].add(gn)
538        return candidates
539
540    def largest_common_subgraph(self, symmetry=True):
541        """
542        Find the largest common induced subgraphs between :attr:`subgraph` and
543        :attr:`graph`.
544
545        Parameters
546        ----------
547        symmetry: bool
548            Whether symmetry should be taken into account. If False, found
549            largest common subgraphs may be symmetrically equivalent.
550
551        Yields
552        ------
553        dict
554            The found isomorphism mappings of {graph_node: subgraph_node}.
555        """
556        # The networkx VF2 algorithm is slightly funny in when it yields an
557        # empty dict and when not.
558        if not self.subgraph:
559            yield {}
560            return
561        elif not self.graph:
562            return
563
564        if symmetry:
565            _, cosets = self.analyze_symmetry(
566                self.subgraph, self._sgn_partitions, self._sge_colors
567            )
568            constraints = self._make_constraints(cosets)
569        else:
570            constraints = []
571
572        candidates = self._find_nodecolor_candidates()
573
574        if any(candidates.values()):
575            yield from self._largest_common_subgraph(candidates, constraints)
576        else:
577            return
578
579    def analyze_symmetry(self, graph, node_partitions, edge_colors):
580        """
581        Find a minimal set of permutations and corresponding co-sets that
582        describe the symmetry of :attr:`subgraph`.
583
584        Returns
585        -------
586        set[frozenset]
587            The found permutations. This is a set of frozenset of pairs of node
588            keys which can be exchanged without changing :attr:`subgraph`.
589        dict[collections.abc.Hashable, set[collections.abc.Hashable]]
590            The found co-sets. The co-sets is a dictionary of {node key:
591            set of node keys}. Every key-value pair describes which `values`
592            can be interchanged without changing nodes less than `key`.
593        """
594        if self._symmetry_cache is not None:
595            key = hash(
596                (
597                    tuple(graph.nodes),
598                    tuple(graph.edges),
599                    tuple(map(tuple, node_partitions)),
600                    tuple(edge_colors.items()),
601                )
602            )
603            if key in self._symmetry_cache:
604                return self._symmetry_cache[key]
605        node_partitions = list(
606            self._refine_node_partitions(graph, node_partitions, edge_colors)
607        )
608        assert len(node_partitions) == 1
609        node_partitions = node_partitions[0]
610        permutations, cosets = self._process_ordered_pair_partitions(
611            graph, node_partitions, node_partitions, edge_colors
612        )
613        if self._symmetry_cache is not None:
614            self._symmetry_cache[key] = permutations, cosets
615        return permutations, cosets
616
617    def is_isomorphic(self, symmetry=False):
618        """
619        Returns True if :attr:`graph` is isomorphic to :attr:`subgraph` and
620        False otherwise.
621
622        Returns
623        -------
624        bool
625        """
626        return len(self.subgraph) == len(self.graph) and self.subgraph_is_isomorphic(
627            symmetry
628        )
629
630    def subgraph_is_isomorphic(self, symmetry=False):
631        """
632        Returns True if a subgraph of :attr:`graph` is isomorphic to
633        :attr:`subgraph` and False otherwise.
634
635        Returns
636        -------
637        bool
638        """
639        # symmetry=False, since we only need to know whether there is any
640        # example; figuring out all symmetry elements probably costs more time
641        # than it gains.
642        isom = next(self.subgraph_isomorphisms_iter(symmetry=symmetry), None)
643        return isom is not None
644
645    def isomorphisms_iter(self, symmetry=True):
646        """
647        Does the same as :meth:`find_isomorphisms` if :attr:`graph` and
648        :attr:`subgraph` have the same number of nodes.
649        """
650        if len(self.graph) == len(self.subgraph):
651            yield from self.subgraph_isomorphisms_iter(symmetry=symmetry)
652
653    def subgraph_isomorphisms_iter(self, symmetry=True):
654        """Alternative name for :meth:`find_isomorphisms`."""
655        return self.find_isomorphisms(symmetry)
656
657    def _find_nodecolor_candidates(self):
658        """
659        Per node in subgraph find all nodes in graph that have the same color.
660        """
661        candidates = defaultdict(set)
662        for sgn in self.subgraph.nodes:
663            sgn_color = self._sgn_colors[sgn]
664            if sgn_color in self._node_compatibility:
665                gn_color = self._node_compatibility[sgn_color]
666                candidates[sgn].add(frozenset(self._gn_partitions[gn_color]))
667            else:
668                candidates[sgn].add(frozenset())
669        candidates = dict(candidates)
670        for sgn, options in candidates.items():
671            candidates[sgn] = frozenset(options)
672        return candidates
673
674    @staticmethod
675    def _make_constraints(cosets):
676        """
677        Turn cosets into constraints.
678        """
679        constraints = []
680        for node_i, node_ts in cosets.items():
681            for node_t in node_ts:
682                if node_i != node_t:
683                    # Node i must be smaller than node t.
684                    constraints.append((node_i, node_t))
685        return constraints
686
687    @staticmethod
688    def _find_node_edge_color(graph, node_colors, edge_colors):
689        """
690        For every node in graph, come up with a color that combines 1) the
691        color of the node, and 2) the number of edges of a color to each type
692        of node.
693        """
694        counts = defaultdict(lambda: defaultdict(int))
695        for node1, node2 in graph.edges:
696            if (node1, node2) in edge_colors:
697                # FIXME directed graphs
698                ecolor = edge_colors[node1, node2]
699            else:
700                ecolor = edge_colors[node2, node1]
701            # Count per node how many edges it has of what color to nodes of
702            # what color
703            counts[node1][ecolor, node_colors[node2]] += 1
704            counts[node2][ecolor, node_colors[node1]] += 1
705
706        node_edge_colors = dict()
707        for node in graph.nodes:
708            node_edge_colors[node] = node_colors[node], set(counts[node].items())
709
710        return node_edge_colors
711
712    @staticmethod
713    def _get_permutations_by_length(items):
714        """
715        Get all permutations of items, but only permute items with the same
716        length.
717
718        >>> found = list(ISMAGS._get_permutations_by_length([[1], [2], [3, 4], [4, 5]]))
719        >>> answer = [
720        ...     (([1], [2]), ([3, 4], [4, 5])),
721        ...     (([1], [2]), ([4, 5], [3, 4])),
722        ...     (([2], [1]), ([3, 4], [4, 5])),
723        ...     (([2], [1]), ([4, 5], [3, 4])),
724        ... ]
725        >>> found == answer
726        True
727        """
728        by_len = defaultdict(list)
729        for item in items:
730            by_len[len(item)].append(item)
731
732        yield from itertools.product(
733            *(itertools.permutations(by_len[l]) for l in sorted(by_len))
734        )
735
736    @classmethod
737    def _refine_node_partitions(cls, graph, node_partitions, edge_colors, branch=False):
738        """
739        Given a partition of nodes in graph, make the partitions smaller such
740        that all nodes in a partition have 1) the same color, and 2) the same
741        number of edges to specific other partitions.
742        """
743
744        def equal_color(node1, node2):
745            return node_edge_colors[node1] == node_edge_colors[node2]
746
747        node_partitions = list(node_partitions)
748        node_colors = partition_to_color(node_partitions)
749        node_edge_colors = cls._find_node_edge_color(graph, node_colors, edge_colors)
750        if all(
751            are_all_equal(node_edge_colors[node] for node in partition)
752            for partition in node_partitions
753        ):
754            yield node_partitions
755            return
756
757        new_partitions = []
758        output = [new_partitions]
759        for partition in node_partitions:
760            if not are_all_equal(node_edge_colors[node] for node in partition):
761                refined = make_partitions(partition, equal_color)
762                if (
763                    branch
764                    and len(refined) != 1
765                    and len({len(r) for r in refined}) != len([len(r) for r in refined])
766                ):
767                    # This is where it breaks. There are multiple new cells
768                    # in refined with the same length, and their order
769                    # matters.
770                    # So option 1) Hit it with a big hammer and simply make all
771                    # orderings.
772                    permutations = cls._get_permutations_by_length(refined)
773                    new_output = []
774                    for n_p in output:
775                        for permutation in permutations:
776                            new_output.append(n_p + list(permutation[0]))
777                    output = new_output
778                else:
779                    for n_p in output:
780                        n_p.extend(sorted(refined, key=len))
781            else:
782                for n_p in output:
783                    n_p.append(partition)
784        for n_p in output:
785            yield from cls._refine_node_partitions(graph, n_p, edge_colors, branch)
786
787    def _edges_of_same_color(self, sgn1, sgn2):
788        """
789        Returns all edges in :attr:`graph` that have the same colour as the
790        edge between sgn1 and sgn2 in :attr:`subgraph`.
791        """
792        if (sgn1, sgn2) in self._sge_colors:
793            # FIXME directed graphs
794            sge_color = self._sge_colors[sgn1, sgn2]
795        else:
796            sge_color = self._sge_colors[sgn2, sgn1]
797        if sge_color in self._edge_compatibility:
798            ge_color = self._edge_compatibility[sge_color]
799            g_edges = self._ge_partitions[ge_color]
800        else:
801            g_edges = []
802        return g_edges
803
    def _map_nodes(self, sgn, candidates, constraints, mapping=None, to_be_mapped=None):
        """
        Find all subgraph isomorphisms honoring constraints.

        Recursive generator that tries to map the subgraph node `sgn` to each
        of its candidate graph nodes and then recurses on the next unmapped
        subgraph node.

        Parameters
        ----------
        sgn
            The subgraph node to map in this call.
        candidates : dict
            Maps every subgraph node to a frozenset of frozensets of graph
            nodes. The entry for `sgn` is replaced in place (see the note in
            the body).
        constraints
            Container of ``(sgn_a, sgn_b)`` pairs. When ``(sgn_a, sgn_b)`` is
            present, the graph node chosen for ``sgn_b`` must compare greater
            than the one chosen for ``sgn_a``.
        mapping : dict, optional
            Partial mapping of subgraph nodes to graph nodes built so far. It
            is copied, so the caller's dict is not modified.
        to_be_mapped : set, optional
            The subgraph nodes that must end up in the mapping. Defaults to
            all nodes of :attr:`subgraph`.

        Yields
        ------
        dict
            Complete mappings, inverted so that keys are graph nodes and values
            are subgraph nodes.
        """
        if mapping is None:
            mapping = {}
        else:
            mapping = mapping.copy()
        if to_be_mapped is None:
            to_be_mapped = set(self.subgraph.nodes)

        # Note, we modify candidates here. Doesn't seem to affect results, but
        # remember this.
        # candidates = candidates.copy()
        # `intersect` is defined elsewhere in this module; presumably it
        # returns the intersection of all candidate sets -- TODO confirm.
        sgn_candidates = intersect(candidates[sgn])
        candidates[sgn] = frozenset([sgn_candidates])
        for gn in sgn_candidates:
            # We're going to try to map sgn to gn.
            if gn in mapping.values() or sgn not in to_be_mapped:
                # gn is already mapped to something (or sgn is not part of
                # the node set we are trying to map).
                continue  # pragma: no cover

            # REDUCTION and COMBINATION
            # This overwrites the entry from the previous loop iteration, so
            # no explicit "unmap" is needed between candidates.
            mapping[sgn] = gn
            # BASECASE
            if to_be_mapped == set(mapping.keys()):
                # Everything is mapped: emit the mapping as graph -> subgraph.
                yield {v: k for k, v in mapping.items()}
                continue
            left_to_map = to_be_mapped - set(mapping.keys())

            # Shallow copy: the frozenset values are immutable, so extending
            # an entry below rebinds it without touching `candidates`.
            new_candidates = candidates.copy()
            sgn_neighbours = set(self.subgraph[sgn])
            not_gn_neighbours = set(self.graph.nodes) - set(self.graph[gn])
            for sgn2 in left_to_map:
                if sgn2 not in sgn_neighbours:
                    # No edge sgn-sgn2 in the subgraph, so sgn2 may only map
                    # to graph nodes that are not adjacent to gn.
                    gn2_options = not_gn_neighbours
                else:
                    # Get all edges to gn of the right color:
                    g_edges = self._edges_of_same_color(sgn, sgn2)
                    # FIXME directed graphs
                    # And all nodes involved in those which are connected to gn
                    # (this includes gn itself; it is rejected later by the
                    # "already mapped" check).
                    gn2_options = {n for e in g_edges for n in e if gn in e}
                # Node color compatibility should be taken care of by the
                # initial candidate lists made by find_subgraphs

                # Add gn2_options to the right collection. Since new_candidates
                # is a dict of frozensets of frozensets of node indices it's
                # a bit clunky. We can't do .add, and + also doesn't work. We
                # could do |, but I deem union to be clearer.
                new_candidates[sgn2] = new_candidates[sgn2].union(
                    [frozenset(gn2_options)]
                )

                # Symmetry-breaking: a constraint between sgn and sgn2 further
                # restricts sgn2 to graph nodes ordered after/before gn.
                if (sgn, sgn2) in constraints:
                    gn2_options = {gn2 for gn2 in self.graph if gn2 > gn}
                elif (sgn2, sgn) in constraints:
                    gn2_options = {gn2 for gn2 in self.graph if gn2 < gn}
                else:
                    continue  # pragma: no cover
                new_candidates[sgn2] = new_candidates[sgn2].union(
                    [frozenset(gn2_options)]
                )

            # The next node is the one that is unmapped and has fewest
            # candidates
            # NOTE(review): the key below returns a frozenset (the smallest
            # candidate set), so the outer `min` compares sets with `<`, which
            # is a proper-subset test rather than a size comparison. Confirm
            # whether the size of that set was the intended key; changing it
            # may alter the order in which results are yielded.
            # Pylint disables because it's a one-shot function.
            next_sgn = min(
                left_to_map, key=lambda n: min(new_candidates[n], key=len)
            )  # pylint: disable=cell-var-from-loop
            yield from self._map_nodes(
                next_sgn,
                new_candidates,
                constraints,
                mapping=mapping,
                to_be_mapped=to_be_mapped,
            )
            # Unmap sgn-gn. Strictly not necessary since it'd get overwritten
            # when making a new mapping for sgn.
            # del mapping[sgn]
883
884    def _largest_common_subgraph(self, candidates, constraints, to_be_mapped=None):
885        """
886        Find all largest common subgraphs honoring constraints.
887        """
888        if to_be_mapped is None:
889            to_be_mapped = {frozenset(self.subgraph.nodes)}
890
891        # The LCS problem is basically a repeated subgraph isomorphism problem
892        # with smaller and smaller subgraphs. We store the nodes that are
893        # "part of" the subgraph in to_be_mapped, and we make it a little
894        # smaller every iteration.
895
896        # pylint disable becuase it's guarded against by default value
897        current_size = len(
898            next(iter(to_be_mapped), [])
899        )  # pylint: disable=stop-iteration-return
900
901        found_iso = False
902        if current_size <= len(self.graph):
903            # There's no point in trying to find isomorphisms of
904            # graph >= subgraph if subgraph has more nodes than graph.
905
906            # Try the isomorphism first with the nodes with lowest ID. So sort
907            # them. Those are more likely to be part of the final
908            # correspondence. This makes finding the first answer(s) faster. In
909            # theory.
910            for nodes in sorted(to_be_mapped, key=sorted):
911                # Find the isomorphism between subgraph[to_be_mapped] <= graph
912                next_sgn = min(nodes, key=lambda n: min(candidates[n], key=len))
913                isomorphs = self._map_nodes(
914                    next_sgn, candidates, constraints, to_be_mapped=nodes
915                )
916
917                # This is effectively `yield from isomorphs`, except that we look
918                # whether an item was yielded.
919                try:
920                    item = next(isomorphs)
921                except StopIteration:
922                    pass
923                else:
924                    yield item
925                    yield from isomorphs
926                    found_iso = True
927
928        # BASECASE
929        if found_iso or current_size == 1:
930            # Shrinking has no point because either 1) we end up with a smaller
931            # common subgraph (and we want the largest), or 2) there'll be no
932            # more subgraph.
933            return
934
935        left_to_be_mapped = set()
936        for nodes in to_be_mapped:
937            for sgn in nodes:
938                # We're going to remove sgn from to_be_mapped, but subject to
939                # symmetry constraints. We know that for every constraint we
940                # have those subgraph nodes are equal. So whenever we would
941                # remove the lower part of a constraint, remove the higher
942                # instead. This is all dealth with by _remove_node. And because
943                # left_to_be_mapped is a set, we don't do double work.
944
945                # And finally, make the subgraph one node smaller.
946                # REDUCTION
947                new_nodes = self._remove_node(sgn, nodes, constraints)
948                left_to_be_mapped.add(new_nodes)
949        # COMBINATION
950        yield from self._largest_common_subgraph(
951            candidates, constraints, to_be_mapped=left_to_be_mapped
952        )
953
954    @staticmethod
955    def _remove_node(node, nodes, constraints):
956        """
957        Returns a new set where node has been removed from nodes, subject to
958        symmetry constraints. We know, that for every constraint we have
959        those subgraph nodes are equal. So whenever we would remove the
960        lower part of a constraint, remove the higher instead.
961        """
962        while True:
963            for low, high in constraints:
964                if low == node and high in nodes:
965                    node = high
966                    break
967            else:  # no break, couldn't find node in constraints
968                break
969        return frozenset(nodes - {node})
970
971    @staticmethod
972    def _find_permutations(top_partitions, bottom_partitions):
973        """
974        Return the pairs of top/bottom partitions where the partitions are
975        different. Ensures that all partitions in both top and bottom
976        partitions have size 1.
977        """
978        # Find permutations
979        permutations = set()
980        for top, bot in zip(top_partitions, bottom_partitions):
981            # top and bot have only one element
982            if len(top) != 1 or len(bot) != 1:
983                raise IndexError(
984                    "Not all nodes are coupled. This is"
985                    f" impossible: {top_partitions}, {bottom_partitions}"
986                )
987            if top != bot:
988                permutations.add(frozenset((next(iter(top)), next(iter(bot)))))
989        return permutations
990
991    @staticmethod
992    def _update_orbits(orbits, permutations):
993        """
994        Update orbits based on permutations. Orbits is modified in place.
995        For every pair of items in permutations their respective orbits are
996        merged.
997        """
998        for permutation in permutations:
999            node, node2 = permutation
1000            # Find the orbits that contain node and node2, and replace the
1001            # orbit containing node with the union
1002            first = second = None
1003            for idx, orbit in enumerate(orbits):
1004                if first is not None and second is not None:
1005                    break
1006                if node in orbit:
1007                    first = idx
1008                if node2 in orbit:
1009                    second = idx
1010            if first != second:
1011                orbits[first].update(orbits[second])
1012                del orbits[second]
1013
1014    def _couple_nodes(
1015        self,
1016        top_partitions,
1017        bottom_partitions,
1018        pair_idx,
1019        t_node,
1020        b_node,
1021        graph,
1022        edge_colors,
1023    ):
1024        """
1025        Generate new partitions from top and bottom_partitions where t_node is
1026        coupled to b_node. pair_idx is the index of the partitions where t_ and
1027        b_node can be found.
1028        """
1029        t_partition = top_partitions[pair_idx]
1030        b_partition = bottom_partitions[pair_idx]
1031        assert t_node in t_partition and b_node in b_partition
1032        # Couple node to node2. This means they get their own partition
1033        new_top_partitions = [top.copy() for top in top_partitions]
1034        new_bottom_partitions = [bot.copy() for bot in bottom_partitions]
1035        new_t_groups = {t_node}, t_partition - {t_node}
1036        new_b_groups = {b_node}, b_partition - {b_node}
1037        # Replace the old partitions with the coupled ones
1038        del new_top_partitions[pair_idx]
1039        del new_bottom_partitions[pair_idx]
1040        new_top_partitions[pair_idx:pair_idx] = new_t_groups
1041        new_bottom_partitions[pair_idx:pair_idx] = new_b_groups
1042
1043        new_top_partitions = self._refine_node_partitions(
1044            graph, new_top_partitions, edge_colors
1045        )
1046        new_bottom_partitions = self._refine_node_partitions(
1047            graph, new_bottom_partitions, edge_colors, branch=True
1048        )
1049        new_top_partitions = list(new_top_partitions)
1050        assert len(new_top_partitions) == 1
1051        new_top_partitions = new_top_partitions[0]
1052        for bot in new_bottom_partitions:
1053            yield list(new_top_partitions), bot
1054
1055    def _process_ordered_pair_partitions(
1056        self,
1057        graph,
1058        top_partitions,
1059        bottom_partitions,
1060        edge_colors,
1061        orbits=None,
1062        cosets=None,
1063    ):
1064        """
1065        Processes ordered pair partitions as per the reference paper. Finds and
1066        returns all permutations and cosets that leave the graph unchanged.
1067        """
1068        if orbits is None:
1069            orbits = [{node} for node in graph.nodes]
1070        else:
1071            # Note that we don't copy orbits when we are given one. This means
1072            # we leak information between the recursive branches. This is
1073            # intentional!
1074            orbits = orbits
1075        if cosets is None:
1076            cosets = {}
1077        else:
1078            cosets = cosets.copy()
1079
1080        assert all(
1081            len(t_p) == len(b_p) for t_p, b_p in zip(top_partitions, bottom_partitions)
1082        )
1083
1084        # BASECASE
1085        if all(len(top) == 1 for top in top_partitions):
1086            # All nodes are mapped
1087            permutations = self._find_permutations(top_partitions, bottom_partitions)
1088            self._update_orbits(orbits, permutations)
1089            if permutations:
1090                return [permutations], cosets
1091            else:
1092                return [], cosets
1093
1094        permutations = []
1095        unmapped_nodes = {
1096            (node, idx)
1097            for idx, t_partition in enumerate(top_partitions)
1098            for node in t_partition
1099            if len(t_partition) > 1
1100        }
1101        node, pair_idx = min(unmapped_nodes)
1102        b_partition = bottom_partitions[pair_idx]
1103
1104        for node2 in sorted(b_partition):
1105            if len(b_partition) == 1:
1106                # Can never result in symmetry
1107                continue
1108            if node != node2 and any(
1109                node in orbit and node2 in orbit for orbit in orbits
1110            ):
1111                # Orbit prune branch
1112                continue
1113            # REDUCTION
1114            # Couple node to node2
1115            partitions = self._couple_nodes(
1116                top_partitions,
1117                bottom_partitions,
1118                pair_idx,
1119                node,
1120                node2,
1121                graph,
1122                edge_colors,
1123            )
1124            for opp in partitions:
1125                new_top_partitions, new_bottom_partitions = opp
1126
1127                new_perms, new_cosets = self._process_ordered_pair_partitions(
1128                    graph,
1129                    new_top_partitions,
1130                    new_bottom_partitions,
1131                    edge_colors,
1132                    orbits,
1133                    cosets,
1134                )
1135                # COMBINATION
1136                permutations += new_perms
1137                cosets.update(new_cosets)
1138
1139        mapped = {
1140            k
1141            for top, bottom in zip(top_partitions, bottom_partitions)
1142            for k in top
1143            if len(top) == 1 and top == bottom
1144        }
1145        ks = {k for k in graph.nodes if k < node}
1146        # Have all nodes with ID < node been mapped?
1147        find_coset = ks <= mapped and node not in cosets
1148        if find_coset:
1149            # Find the orbit that contains node
1150            for orbit in orbits:
1151                if node in orbit:
1152                    cosets[node] = orbit.copy()
1153        return permutations, cosets
1154