1"""
2Equitable coloring of graphs with bounded degree.
3"""
4
5import networkx as nx
6from collections import defaultdict
7
# Only ``equitable_color`` is exported by ``from ... import *``; the helpers
# below are module-internal.
__all__ = ["equitable_color"]
9
10
def is_coloring(G, coloring):
    """Return True when no edge of ``G`` joins two nodes of the same color.

    ``coloring`` must map every endpoint of every edge in ``G.edges`` to a
    color; a missing endpoint raises ``KeyError``. Checking stops at the first
    monochromatic edge found.
    """
    return not any(coloring[a] == coloring[b] for a, b in G.edges)
18
19
def is_equitable(G, coloring, num_colors=None):
    """Return True if ``coloring`` is a proper, equitable coloring of ``G``.

    A coloring is equitable when the sizes of any two color classes differ by
    at most one. When ``num_colors`` is given, every color in
    ``range(num_colors)`` counts as a class, so unused colors are classes of
    size zero.
    """
    if not is_coloring(G, coloring):
        return False

    class_sizes = defaultdict(int)
    for color in coloring.values():
        class_sizes[color] += 1

    if num_colors is not None:
        # Colors nobody uses still form (empty) classes.
        for color in range(num_colors):
            class_sizes.setdefault(color, 0)

    distinct_sizes = set(class_sizes.values())
    if not distinct_sizes:
        # Empty coloring: reported as equitable only when num_colors was not
        # supplied (e.g. num_colors=0 yields False).
        return num_colors is None
    # Distinct integer sizes: equitable iff the spread is at most one.
    return max(distinct_sizes) - min(distinct_sizes) <= 1
48
49
def make_C_from_F(F):
    """Invert a node -> color map into a color -> list-of-nodes map.

    Colors appear in the order they are first met in ``F`` and each list keeps
    the iteration order of ``F``. The result is a ``defaultdict(list)``, so
    looking up an unused color yields (and stores) an empty list.
    """
    classes = defaultdict(list)
    for node in F:
        classes[F[node]].append(node)
    return classes
56
57
def make_N_from_L_C(L, C):
    """Count, for every node and every color, the node's neighbors of that color.

    Parameters
    ----------
    L : dict
        Node -> iterable of neighbors.
    C : dict
        Color -> list of nodes with that color.

    Returns
    -------
    dict
        ``N[(node, color)]`` = number of entries of ``L[node]`` that belong to
        ``C[color]``. Keys are ordered node-major, then by color, following
        the iteration order of ``L`` and ``C``.
    """
    # Build each class's membership set once so every neighbor test is O(1)
    # instead of a linear scan of the class list.
    members = {color: set(nodes) for color, nodes in C.items()}
    return {
        (node, color): sum(1 for v in L[node] if v in members[color])
        for node in L
        for color in C
    }
66
67
def make_H_from_C_N(C, N):
    """Build the witness-count table ``H`` from classes ``C`` and counts ``N``.

    ``H[(src, dst)]`` is the number of nodes in class ``src`` that have no
    neighbor colored ``dst`` (``N[(node, dst)] == 0``), i.e. nodes that could
    be recolored from ``src`` to ``dst`` without a conflict.
    """
    table = {}
    for src in C.keys():
        for dst in C.keys():
            free_nodes = [node for node in C[src] if N[(node, dst)] == 0]
            table[(src, dst)] = len(free_nodes)
    return table
74
75
def change_color(u, X, Y, N, H, F, C, L):
    """Move node ``u`` from color class ``X`` to class ``Y``, in place.

    Besides recoloring ``u`` in ``F`` and ``C``, the bookkeeping tables stay
    consistent:

    * ``N[(node, color)]`` -- neighbors of ``node`` (per ``L``) with ``color``;
    * ``H[(c1, c2)]`` -- nodes of class ``c1`` with no neighbor in ``c2``.

    ``u`` must currently be colored ``X`` and ``X`` must differ from ``Y``.
    """
    assert F[u] == X and X != Y

    F[u] = Y

    # u used to be counted under row X of H for every class it has no
    # neighbor in; those contributions now belong to row Y.
    for target in C.keys():
        if N[u, target] != 0:
            continue
        H[(X, target)] -= 1
        H[(Y, target)] += 1

    # Every neighbor of u now sees one fewer neighbor in X and one more in Y.
    for nbr in L[u]:
        N[(nbr, X)] -= 1
        N[(nbr, Y)] += 1

        lost_last_in_X = N[(nbr, X)] == 0
        gained_first_in_Y = N[(nbr, Y)] == 1
        if lost_last_in_X:
            # nbr has no neighbor left in X, so it now counts toward F[nbr] -> X.
            H[(F[nbr], X)] += 1
        if gained_first_in_Y:
            # nbr just got its first neighbor in Y; it stops counting toward F[nbr] -> Y.
            H[(F[nbr], Y)] -= 1

    C[X].remove(u)
    C[Y].append(u)
104
105
def move_witnesses(src_color, dst_color, N, H, F, C, T_cal, L):
    """Shift one node per step along the chain ``src_color -> ... -> dst_color``.

    ``T_cal[color]`` names the next class on the chain. At each step the first
    node of the current class with no neighbor in the next class (a witness)
    is recolored into that class via ``change_color``. Raises ``IndexError``
    if some step has no witness; loops forever if the chain never reaches
    ``dst_color`` (``T_cal`` is trusted to be a path).
    """
    current = src_color
    while current != dst_color:
        successor = T_cal[current]
        witnesses = [node for node in C[current] if N[(node, successor)] == 0]
        change_color(witnesses[0], current, successor, N=N, H=H, F=F, C=C, L=L)
        current = successor
115
116
def pad_graph(G, num_colors):
    """Add a disconnected complete clique K_p such that the number of nodes in
    the graph becomes a multiple of `num_colors`.

    ``p`` is the smallest positive count that makes ``len(G)`` a multiple of
    ``num_colors``; nothing is added if it already is one. The new nodes are
    labelled ``len(G), ..., len(G) + p - 1`` and ``G`` is modified in place.

    Assumes that the graph's nodes are labelled with the integers
    ``0, ..., len(G) - 1`` so the new labels cannot collide with existing ones.

    Returns
    -------
    int
        The number of nodes with each color in an equitable coloring of the
        padded graph, i.e. ``len(G) // num_colors`` after padding.
    """
    from itertools import combinations

    n_ = len(G)
    s, remainder = divmod(n_, num_colors)
    if remainder:
        p = num_colors - remainder
        s += 1

        new_nodes = range(n_, n_ + p)
        # Add the nodes explicitly: K_1 has no edges, so adding only the
        # clique's edges would silently leave the graph unpadded when p == 1.
        G.add_nodes_from(new_nodes)
        G.add_edges_from(combinations(new_nodes, 2))

    return s
140
141
def procedure_P(V_minus, V_plus, N, H, F, C, L, excluded_colors=None):
    """Procedure P as described in the paper.

    Called right after a node has been recolored from class ``V_minus`` into
    class ``V_plus``. It moves nodes between classes along chains of
    "witnesses" (nodes that can be recolored into another class without a
    conflict) to restore an equitable coloring. ``N``, ``H``, ``F`` and ``C``
    are updated in place; nothing is returned.

    Parameters
    ----------
    V_minus : color
        The class that has just lost a node.
    V_plus : color
        The class that has just gained a node.
    N : dict
        ``N[(node, color)]`` is the number of neighbors of ``node`` (per ``L``)
        colored ``color``.
    H : dict
        ``H[(c1, c2)]`` is the number of nodes of class ``c1`` with no neighbor
        in ``c2``, i.e. nodes that could be moved from ``c1`` to ``c2``.
    F : dict
        Node -> color.
    C : dict
        Color -> list of nodes with that color.
    L : dict
        Node -> list of neighbors considered so far.
    excluded_colors : set, optional
        Colors the first BFS may not enter. The Case II branch below
        ``update``s this set in place before recursing.
    """

    if excluded_colors is None:
        excluded_colors = set()

    # A_cal: colors from which a chain of witness moves can reach V_minus.
    # T_cal[k]: the next color on such a chain starting at k.
    # R_cal: the colors of A_cal in BFS discovery order.
    A_cal = set()
    T_cal = {}
    R_cal = []

    # BFS to determine A_cal, i.e. colors reachable from V-
    reachable = [V_minus]
    marked = set(reachable)
    idx = 0

    while idx < len(reachable):
        pop = reachable[idx]
        idx += 1

        A_cal.add(pop)
        R_cal.append(pop)

        # TODO: Checking whether a color has been visited can be made faster by
        # using a look-up table instead of testing for membership in a set by a
        # logarithmic factor.
        next_layer = []
        for k in C.keys():
            # H[(k, pop)] > 0: some node of class k has no neighbor in pop and
            # could therefore be moved into pop.
            if (
                H[(k, pop)] > 0
                and k not in A_cal
                and k not in excluded_colors
                and k not in marked
            ):
                next_layer.append(k)

        for dst in next_layer:
            # Record that `dst` can reach `pop`
            T_cal[dst] = pop

        marked.update(next_layer)
        reachable.extend(next_layer)

    # Variables for the algorithm
    # b: number of classes outside A_cal.
    b = len(C) - len(A_cal)

    if V_plus in A_cal:
        # Easy case: V+ is in A_cal
        # Move one node from V+ to V- using T_cal to find the parents.
        move_witnesses(V_plus, V_minus, N=N, H=H, F=F, C=C, T_cal=T_cal, L=L)
    else:
        # If there is a solo edge, we can resolve the situation by
        # moving witnesses from B to A, making G[A] equitable and then
        # recursively balancing G[B - w] with a different V_minus and
        # but the same V_plus.

        # NOTE(review): A_0 is filled below but never read in this function.
        A_0 = set()
        # A_cal_0: classes of A_cal in which no node led to a solo-edge move.
        A_cal_0 = set()
        num_terminal_sets_found = 0
        made_equitable = False

        # Visit the classes of A_cal in reverse BFS order (last discovered first).
        for W_1 in R_cal[::-1]:

            for v in C[W_1]:
                X = None

                # Find a class X in A_cal (other than W_1) in which v has no
                # neighbor; if several qualify, the last one in C's order wins.
                for U in C.keys():
                    if N[(v, U)] == 0 and U in A_cal and U != W_1:
                        X = U

                # v does not witness an edge in H[A_cal]
                if X is None:
                    continue

                for U in C.keys():
                    # Note: Departing from the paper here.
                    if N[(v, U)] >= 1 and U not in A_cal:
                        X_prime = U
                        w = v

                        # Finding the solo neighbor of w in X_prime
                        # (a neighbor y colored X_prime whose only neighbor in
                        # W_1 is w).
                        y_candidates = [
                            node
                            for node in L[w]
                            if F[node] == X_prime and N[(node, W_1)] == 1
                        ]

                        if len(y_candidates) > 0:
                            y = y_candidates[0]
                            W = W_1

                            # Move w from W to X, now X has one extra node.
                            change_color(w, W, X, N=N, H=H, F=F, C=C, L=L)

                            # Move witness from X to V_minus, making the coloring
                            # equitable.
                            move_witnesses(
                                src_color=X,
                                dst_color=V_minus,
                                N=N,
                                H=H,
                                F=F,
                                C=C,
                                T_cal=T_cal,
                                L=L,
                            )

                            # Move y from X_prime to W, making W the correct size.
                            # (w, y's only neighbor in W, has just left W.)
                            change_color(y, X_prime, W, N=N, H=H, F=F, C=C, L=L)

                            # Then call the procedure on G[B - y]
                            procedure_P(
                                V_minus=X_prime,
                                V_plus=V_plus,
                                N=N,
                                H=H,
                                C=C,
                                F=F,
                                L=L,
                                excluded_colors=excluded_colors.union(A_cal),
                            )
                            made_equitable = True
                            break

                if made_equitable:
                    break
            else:
                # No node in W_1 was found such that
                # it had a solo-neighbor.
                # (This `else` belongs to `for v in C[W_1]` and runs whenever
                # that loop finished without `break`.)
                A_cal_0.add(W_1)
                A_0.update(C[W_1])
                num_terminal_sets_found += 1

            if num_terminal_sets_found == b:
                # Otherwise, construct the maximal independent set and find
                # a pair of z_1, z_2 as in Case II.

                # BFS to determine B_cal': the set of colors reachable from V+
                B_cal_prime = set()
                T_cal_prime = {}

                reachable = [V_plus]
                marked = set(reachable)
                idx = 0
                while idx < len(reachable):
                    pop = reachable[idx]
                    idx += 1

                    B_cal_prime.add(pop)

                    # No need to check for excluded_colors here because
                    # they only exclude colors from A_cal
                    next_layer = [
                        k
                        for k in C.keys()
                        if H[(pop, k)] > 0 and k not in B_cal_prime and k not in marked
                    ]

                    # NOTE(review): T_cal_prime keeps a single successor per
                    # color (the last dst written wins), so when `pop` has
                    # several children only one survives. Confirm that the
                    # later walk move_witnesses(V_plus, Z, T_cal_prime) always
                    # reaches Z.
                    for dst in next_layer:
                        T_cal_prime[pop] = dst

                    marked.update(next_layer)
                    reachable.extend(next_layer)

                # Construct the independent set of G[B']
                I_set = set()
                I_covered = set()
                # W_covering[w] = the first z in I_set whose only neighbor in
                # class F[w] (a class of A_cal_0) is w.
                W_covering = {}

                B_prime = [node for k in B_cal_prime for node in C[k]]

                # Add the nodes in V_plus to I first.
                for z in C[V_plus] + B_prime:
                    if z in I_covered or F[z] not in B_cal_prime:
                        continue

                    I_set.add(z)
                    I_covered.add(z)
                    I_covered.update([nbr for nbr in L[z]])

                    for w in L[z]:
                        # w is z's only neighbor in the class F[w].
                        if F[w] in A_cal_0 and N[(z, F[w])] == 1:
                            if w not in W_covering:
                                W_covering[w] = z
                            else:
                                # Found z1, z2 which have the same solo
                                # neighbor in some W
                                z_1 = W_covering[w]
                                # z_2 = z

                                Z = F[z_1]
                                W = F[w]

                                # shift nodes along W, V-
                                move_witnesses(
                                    W, V_minus, N=N, H=H, F=F, C=C, T_cal=T_cal, L=L
                                )

                                # shift nodes along V+ to Z
                                move_witnesses(
                                    V_plus,
                                    Z,
                                    N=N,
                                    H=H,
                                    F=F,
                                    C=C,
                                    T_cal=T_cal_prime,
                                    L=L,
                                )

                                # change color of z_1 to W
                                change_color(z_1, Z, W, N=N, H=H, F=F, C=C, L=L)

                                # change color of w to some color in B_cal
                                W_plus = [
                                    k
                                    for k in C.keys()
                                    if N[(w, k)] == 0 and k not in A_cal
                                ][0]
                                change_color(w, W, W_plus, N=N, H=H, F=F, C=C, L=L)

                                # recurse with G[B \cup W*]
                                # NOTE: this mutates the caller-visible
                                # excluded_colors set in place.
                                excluded_colors.update(
                                    [
                                        k
                                        for k in C.keys()
                                        if k != W and k not in B_cal_prime
                                    ]
                                )
                                procedure_P(
                                    V_minus=W,
                                    V_plus=W_plus,
                                    N=N,
                                    H=H,
                                    C=C,
                                    F=F,
                                    L=L,
                                    excluded_colors=excluded_colors,
                                )

                                made_equitable = True
                                break

                    if made_equitable:
                        break
                else:
                    # Reached only if the scan over z never hit `break`.
                    assert False, (
                        "Must find a w which is the solo neighbor "
                        "of two vertices in B_cal_prime."
                    )

            if made_equitable:
                break
394
395
def equitable_color(G, num_colors):
    """Provides equitable (r + 1)-coloring for nodes of G in O(r * n^2) time
    if deg(G) <= r. The algorithm is described in [1]_.

    Attempts to color a graph using `num_colors` colors, where no neighbors of
    a node can have the same color as the node itself and the number of nodes
    with each color differ by at most 1.

    Parameters
    ----------
    G : networkX graph
       The nodes of this graph will be colored.

    num_colors : int
       Number of colors to use. This number must be at least one more than the
       maximum degree of nodes in the graph.

    Returns
    -------
    A dictionary with keys representing nodes and values representing
    corresponding coloring.

    Examples
    --------
    >>> G = nx.cycle_graph(4)
    >>> d = nx.coloring.equitable_color(G, num_colors=3)
    >>> nx.algorithms.coloring.equitable_coloring.is_equitable(G, d)
    True

    Raises
    ------
    NetworkXAlgorithmError
        If the maximum degree of the graph ``G`` is greater than or equal to
        ``num_colors``.

    References
    ----------
    .. [1] Kierstead, H. A., Kostochka, A. V., Mydlarz, M., & Szemerédi, E.
        (2010). A fast algorithm for equitable coloring. Combinatorica, 30(2),
        217-224.
    """

    # Map nodes to integers for simplicity later.
    nodes_to_int = {}
    int_to_nodes = {}

    for idx, node in enumerate(G.nodes):
        nodes_to_int[node] = idx
        int_to_nodes[idx] = node

    # Work on a relabelled copy so the caller's graph is never modified
    # (pad_graph below adds nodes to it).
    G = nx.relabel_nodes(G, nodes_to_int, copy=True)

    # Basic graph statistics and sanity check.
    r_ = max((deg for _, deg in G.degree), default=0)

    if r_ >= num_colors:
        raise nx.NetworkXAlgorithmError(
            f"Graph has maximum degree {r_}, needs "
            f"{r_ + 1} (> {num_colors}) colors for guaranteed coloring."
        )

    # Ensure that the number of nodes in G is a multiple of (r + 1)
    pad_graph(G, num_colors)

    # Starting the algorithm.
    # Edges are introduced one at a time below; L_ holds the adjacency lists
    # of the edges seen so far.
    L_ = {node: [] for node in G.nodes}

    # Arbitrary equitable allocation of colors to nodes.
    F = {node: idx % num_colors for idx, node in enumerate(G.nodes)}

    C = make_C_from_F(F)

    # The neighborhood is empty initially.
    N = make_N_from_L_C(L_, C)

    # Currently all nodes witness all edges.
    H = make_H_from_C_N(C, N)

    # Start of algorithm.
    edges_seen = set()

    for u in sorted(G.nodes):
        for v in sorted(G.neighbors(u)):

            # Do not double count edges if (v, u) has already been seen.
            if (v, u) in edges_seen:
                continue

            edges_seen.add((u, v))

            L_[u].append(v)
            L_[v].append(u)

            N[(u, F[v])] += 1
            N[(v, F[u])] += 1

            if F[u] != F[v]:
                # Were 'u' and 'v' witnesses for F[u] -> F[v] or F[v] -> F[u]?
                if N[(u, F[v])] == 1:
                    H[F[u], F[v]] -= 1  # u cannot witness an edge between F[u], F[v]

                if N[(v, F[u])] == 1:
                    H[F[v], F[u]] -= 1  # v cannot witness an edge between F[v], F[u]

        # 'u' now conflicts with a neighbor of its own color.
        if N[(u, F[u])] != 0:
            # Find the first color where 'u' does not have any neighbors.
            Y = [k for k in C.keys() if N[(u, k)] == 0][0]
            X = F[u]
            change_color(u, X, Y, N=N, H=H, F=F, C=C, L=L_)

            # Procedure P: X has lost a node and Y has gained one; rebalance.
            procedure_P(V_minus=X, V_plus=Y, N=N, H=H, F=F, C=C, L=L_)

    # Padding nodes are not in int_to_nodes, so they are dropped here.
    return {int_to_nodes[x]: F[x] for x in int_to_nodes}
514