1"""
2Tools for dealing with a directed graph.
3
4"""
from functools import wraps
from math import gcd

import numpy as np
from numba import jit
from scipy import sparse
from scipy.sparse import csgraph

from .util import check_random_state
12
13
14# Decorator for *_components properties
def annotate_nodes(func):
    """
    Decorate a method that returns a list of node-index arrays so that
    the indices are replaced by the corresponding node labels.

    Parameters
    ----------
    func : callable
        Method taking only `self` and returning a list of index arrays.
        `self` must expose a `node_labels` attribute.

    Returns
    -------
    callable
        Method that returns ``[self.node_labels[c] for c in components]``
        when `self.node_labels` is not None, and the unmodified output
        of `func` otherwise.

    """
    # `wraps` keeps the name and docstring of the decorated method, so
    # that properties built on top of it remain introspectable.
    @wraps(func)
    def new_func(self):
        list_of_components = func(self)
        if self.node_labels is not None:
            return [self.node_labels[c] for c in list_of_components]
        return list_of_components
    return new_func
22
23
class DiGraph:
    r"""
    Class for a directed graph. It stores useful information about the
    graph structure such as strong connectivity [1]_ and periodicity
    [2]_.

    Parameters
    ----------
    adj_matrix : array_like(ndim=2)
        Adjacency matrix representing a directed graph. Must be of shape
        n x n.

    weighted : bool, optional(default=False)
        Whether to treat `adj_matrix` as a weighted adjacency matrix.

    node_labels : array_like(default=None)
        Array_like of length n containing the labels associated with the
        nodes, which must be homogeneous in type. If None, the labels
        default to integers 0 through n-1.

    Attributes
    ----------
    csgraph : scipy.sparse.csr_matrix
        Compressed sparse representation of the digraph.

    is_strongly_connected : bool
        Indicate whether the digraph is strongly connected.

    num_strongly_connected_components : int
        The number of the strongly connected components.

    strongly_connected_components_indices : list(ndarray(int))
        List of numpy arrays containing the indices of the strongly
        connected components.

    strongly_connected_components : list(ndarray)
        List of numpy arrays containing the strongly connected
        components, where the nodes are annotated with their labels (if
        `node_labels` is not None).

    num_sink_strongly_connected_components : int
        The number of the sink strongly connected components.

    sink_strongly_connected_components_indices : list(ndarray(int))
        List of numpy arrays containing the indices of the sink strongly
        connected components.

    sink_strongly_connected_components : list(ndarray)
        List of numpy arrays containing the sink strongly connected
        components, where the nodes are annotated with their labels (if
        `node_labels` is not None).

    is_aperiodic : bool
        Indicate whether the digraph is aperiodic.

    period : int
        The period of the digraph. Defined only for a strongly connected
        digraph.

    cyclic_components_indices : list(ndarray(int))
        List of numpy arrays containing the indices of the cyclic
        components.

    cyclic_components : list(ndarray)
        List of numpy arrays containing the cyclic components, where the
        nodes are annotated with their labels (if `node_labels` is not
        None).

    References
    ----------
    .. [1] `Strongly connected component
       <http://en.wikipedia.org/wiki/Strongly_connected_component>`_,
       Wikipedia.

    .. [2] `Aperiodic graph
       <http://en.wikipedia.org/wiki/Aperiodic_graph>`_, Wikipedia.

    """

    def __init__(self, adj_matrix, weighted=False, node_labels=None):
        # An unweighted graph only records whether an edge exists, so its
        # entries are coerced to bool; a weighted graph keeps the dtype
        # inferred from `adj_matrix`.
        dtype = None if weighted else bool
        self.csgraph = sparse.csr_matrix(adj_matrix, dtype=dtype)

        m, n = self.csgraph.shape
        if n != m:
            raise ValueError('input matrix must be square')

        self.n = n  # Number of nodes

        # Call the setter method, which validates the labels
        self.node_labels = node_labels

        # Caches filled in lazily by _find_scc, _find_sink_scc and
        # _compute_period; None means "not computed yet"
        self._num_scc = None
        self._scc_proj = None
        self._sink_scc_labels = None

        self._period = None
        self._cyclic_components_proj = None

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return "Directed Graph:\n  - n(number of nodes): {n}".format(n=self.n)

    @property
    def node_labels(self):
        """
        Labels associated with the nodes (ndarray of length n), or None
        if no labels have been set.

        """
        return self._node_labels

    @node_labels.setter
    def node_labels(self, values):
        """
        Validate and store the node labels.

        Raises
        ------
        ValueError
            If `values` is not None and is either not of length n or is
            converted by numpy to an array of object dtype.

        """
        if values is None:
            self._node_labels = None
            return

        values = np.asarray(values)
        if (values.ndim < 1) or (values.shape[0] != self.n):
            raise ValueError(
                'node_labels must be an array_like of length n'
            )
        # numpy falls back to object dtype when it cannot find a common
        # type for the labels
        if np.issubdtype(values.dtype, np.object_):
            raise ValueError(
                'data in node_labels must be homogeneous in type'
            )
        self._node_labels = values

    def _find_scc(self):
        """
        Set ``self._num_scc`` and ``self._scc_proj``
        by calling ``scipy.sparse.csgraph.connected_components``:
        * docs.scipy.org/doc/scipy/reference/sparse.csgraph.html
        * github.com/scipy/scipy/blob/master/scipy/sparse/csgraph/_traversal.pyx

        ``self._scc_proj`` is an array of length `n` that assigns to each
        node the label of the strongly connected component to which it
        belongs.

        """
        # Find the strongly connected components
        self._num_scc, self._scc_proj = \
            csgraph.connected_components(self.csgraph, connection='strong')

    @property
    def num_strongly_connected_components(self):
        """Number of strongly connected components (computed lazily)."""
        if self._num_scc is None:
            self._find_scc()
        return self._num_scc

    @property
    def scc_proj(self):
        """
        Array of length n mapping each node to the label of its strongly
        connected component (computed lazily).

        """
        if self._scc_proj is None:
            self._find_scc()
        return self._scc_proj

    @property
    def is_strongly_connected(self):
        """Whether the digraph has exactly one strongly connected
        component."""
        return (self.num_strongly_connected_components == 1)

    def _condensation_lil(self):
        """
        Return the sparse matrix representation of the condensation digraph
        in lil format.

        The condensation digraph has one node per strongly connected
        component, and an edge from one component to another whenever
        some stored entry of ``self.csgraph`` goes from a node of the
        former to a node of the latter.

        """
        num_scc = self.num_strongly_connected_components
        condensation_lil = sparse.lil_matrix((num_scc, num_scc), dtype=bool)

        scc_proj = self.scc_proj
        for node_from, node_to in _csr_matrix_indices(self.csgraph):
            scc_from, scc_to = scc_proj[node_from], scc_proj[node_to]
            # Edges within a component do not appear in the condensation
            if scc_from != scc_to:
                condensation_lil[scc_from, scc_to] = True

        return condensation_lil

    def _find_sink_scc(self):
        """
        Set ``self._sink_scc_labels``, which is an array containing the
        labels of the sink strongly connected components.

        """
        condensation_lil = self._condensation_lil()

        # A sink SCC is a SCC with no edge leaving it to another SCC,
        # i.e., those k's such that condensation_lil.rows[k] == []
        self._sink_scc_labels = \
            np.where(np.logical_not(condensation_lil.rows))[0]

    @property
    def sink_scc_labels(self):
        """Labels of the sink strongly connected components (computed
        lazily)."""
        if self._sink_scc_labels is None:
            self._find_sink_scc()
        return self._sink_scc_labels

    @property
    def num_sink_strongly_connected_components(self):
        """Number of sink strongly connected components."""
        return len(self.sink_scc_labels)

    @property
    def strongly_connected_components_indices(self):
        """List of index arrays, one per strongly connected component."""
        if self.is_strongly_connected:
            return [np.arange(self.n)]
        return [np.where(self.scc_proj == k)[0]
                for k in range(self.num_strongly_connected_components)]

    @property
    @annotate_nodes
    def strongly_connected_components(self):
        """List of the strongly connected components, with the nodes
        annotated with their labels if `node_labels` is not None."""
        return self.strongly_connected_components_indices

    @property
    def sink_strongly_connected_components_indices(self):
        """List of index arrays, one per sink strongly connected
        component."""
        if self.is_strongly_connected:
            return [np.arange(self.n)]
        return [np.where(self.scc_proj == k)[0]
                for k in self.sink_scc_labels.tolist()]

    @property
    @annotate_nodes
    def sink_strongly_connected_components(self):
        """List of the sink strongly connected components, with the nodes
        annotated with their labels if `node_labels` is not None."""
        return self.sink_strongly_connected_components_indices

    def _set_aperiodic(self):
        """
        Set ``self._period`` to 1 and ``self._cyclic_components_proj`` to
        an all-zero array, i.e., every node is assigned to the cyclic
        component 0.

        """
        self._period = 1
        self._cyclic_components_proj = np.zeros(self.n, dtype=int)

    def _compute_period(self):
        """
        Set ``self._period`` and ``self._cyclic_components_proj``.

        Use the algorithm described in:
        J. P. Jarvis and D. R. Shier,
        "Graph-Theoretic Analysis of Finite Markov Chains," 1996.

        Raises
        ------
        NotImplementedError
            If the digraph has more than one node and is not strongly
            connected.

        """
        # Degenerate graph with a single node (which is strongly connected)
        # csgraph.reconstruct_path would raise an exception
        # github.com/scipy/scipy/issues/4018
        # The period is set to 1 both with a self loop and with no edge
        # ("trivial graph"; any universally accepted definition?)
        if self.n == 1:
            self._set_aperiodic()
            return None

        if not self.is_strongly_connected:
            raise NotImplementedError(
                'Not defined for a non strongly-connected digraph'
            )

        # A self loop is a cycle of length 1
        if np.any(self.csgraph.diagonal() > 0):
            self._set_aperiodic()
            return None

        # Construct a breadth-first search tree rooted at 0
        node_order, predecessors = \
            csgraph.breadth_first_order(self.csgraph, i_start=0)
        bfs_tree_csr = \
            csgraph.reconstruct_path(self.csgraph, predecessors)

        # Edges not belonging to tree_csr
        non_bfs_tree_csr = self.csgraph - bfs_tree_csr
        non_bfs_tree_csr.eliminate_zeros()

        # Distance to 0; nodes are visited in BFS order, so the level of
        # the predecessor is already known when a node is reached
        level = np.zeros(self.n, dtype=int)
        for node in node_order[1:]:
            level[node] = level[predecessors[node]] + 1

        # Determine the period as the gcd of the values
        # level[from] - level[to] + 1 over the non-tree edges
        d = 0
        for node_from, node_to in _csr_matrix_indices(non_bfs_tree_csr):
            value = level[node_from] - level[node_to] + 1
            d = gcd(d, value)
            if d == 1:  # The gcd cannot get any smaller
                self._set_aperiodic()
                return None

        self._period = d
        self._cyclic_components_proj = level % d

    @property
    def period(self):
        """Period of the digraph (computed lazily)."""
        if self._period is None:
            self._compute_period()
        return self._period

    @property
    def is_aperiodic(self):
        """Whether the period of the digraph is 1."""
        return (self.period == 1)

    @property
    def cyclic_components_indices(self):
        """List of index arrays, one per cyclic component."""
        if self.is_aperiodic:
            return [np.arange(self.n)]
        return [np.where(self._cyclic_components_proj == k)[0]
                for k in range(self.period)]

    @property
    @annotate_nodes
    def cyclic_components(self):
        """List of the cyclic components, with the nodes annotated with
        their labels if `node_labels` is not None."""
        return self.cyclic_components_indices

    def subgraph(self, nodes):
        """
        Return the subgraph consisting of the given nodes and edges
        between these nodes.

        Parameters
        ----------
        nodes : array_like(int, ndim=1)
           Array of node indices.

        Returns
        -------
        DiGraph
            A DiGraph representing the subgraph.

        """
        adj_matrix = self.csgraph[np.ix_(nodes, nodes)]

        if self.node_labels is not None:
            node_labels = self.node_labels[nodes]
        else:
            node_labels = None

        # weighted=True so that the dtype of self.csgraph is copied
        # rather than coerced to bool
        return DiGraph(adj_matrix, weighted=True, node_labels=node_labels)
360
361
362def _csr_matrix_indices(S):
363    """
364    Generate the indices of nonzero entries of a csr_matrix S
365
366    """
367    m, n = S.shape
368
369    for i in range(m):
370        for j in range(S.indptr[i], S.indptr[i+1]):
371            row_index, col_index = i, S.indices[j]
372            yield row_index, col_index
373
374
def random_tournament_graph(n, random_state=None):
    """
    Return a random tournament graph [1]_ with n nodes.

    For each of the n * (n-1) // 2 pairs of distinct nodes, one random
    number is drawn and used to choose the direction of the single edge
    between the two nodes.

    Parameters
    ----------
    n : scalar(int)
        Number of nodes.

    random_state : int or np.random.RandomState, optional
        Random seed (integer) or np.random.RandomState instance to set
        the initial state of the random number generator for
        reproducibility. If None, a randomly initialized RandomState is
        used.

    Returns
    -------
    DiGraph
        A DiGraph representing the tournament graph.

    References
    ----------
    .. [1] `Tournament (graph theory)
       <https://en.wikipedia.org/wiki/Tournament_(graph_theory)>`_,
       Wikipedia.

    """
    rng = check_random_state(random_state)

    # One edge, hence one random draw, per unordered pair of nodes
    num_edges = n * (n-1) // 2
    draws = rng.random_sample(num_edges)

    # Filled in place with the tail and head node of each edge
    row = np.empty(num_edges, dtype=int)
    col = np.empty(num_edges, dtype=int)
    _populate_random_tournament_row_col(n, draws, row, col)

    adj_matrix = sparse.coo_matrix(
        (np.ones(num_edges, dtype=bool), (row, col)), shape=(n, n)
    )
    return DiGraph(adj_matrix)
411
412
@jit(nopython=True, cache=True)
def _populate_random_tournament_row_col(n, r, row, col):
    """
    Fill ndarrays `row` and `col` in place with the directed edge
    indices of a tournament graph with n nodes, which has
    num_edges = n * (n-1) // 2 edges. The k-th pair (i, j) with i < j
    becomes the edge i -> j if r[k] < 0.5 and the edge j -> i otherwise.

    Parameters
    ----------
    n : scalar(int)
        Number of nodes.

    r : ndarray(float, ndim=1)
        ndarray of length num_edges containing random numbers in [0, 1).

    row, col : ndarray(int, ndim=1)
        ndarrays of length num_edges to be modified in place.

    """
    edge = 0  # Position of the current pair in r, row and col
    for low in range(n):
        for high in range(low + 1, n):
            if r[edge] < 0.5:
                tail, head = low, high
            else:
                tail, head = high, low
            row[edge] = tail
            col[edge] = head
            edge += 1
440