1"""
2**********
3Matplotlib
4**********
5
6Draw networks with matplotlib.
7
8Examples
9--------
10>>> G = nx.complete_graph(5)
11>>> nx.draw(G)
12
13See Also
14--------
15 - :doc:`matplotlib <matplotlib:index>`
16 - :func:`matplotlib.pyplot.scatter`
17 - :obj:`matplotlib.patches.FancyArrowPatch`
18"""
19from numbers import Number
20import networkx as nx
21from networkx.drawing.layout import (
22    shell_layout,
23    circular_layout,
24    kamada_kawai_layout,
25    spectral_layout,
26    spring_layout,
27    random_layout,
28    planar_layout,
29)
30import warnings
31
# Public names of this module, i.e. what ``from <this module> import *``
# exports: the Matplotlib-based graph drawing functions.
__all__ = [
    "draw",
    "draw_networkx",
    "draw_networkx_nodes",
    "draw_networkx_edges",
    "draw_networkx_labels",
    "draw_networkx_edge_labels",
    "draw_circular",
    "draw_kamada_kawai",
    "draw_random",
    "draw_spectral",
    "draw_spring",
    "draw_planar",
    "draw_shell",
]
47
48
def draw(G, pos=None, ax=None, **kwds):
    """Draw the graph G with Matplotlib.

    Draw the graph as a simple representation with no node
    labels or edge labels and using the full Matplotlib figure area
    and no axis labels by default.  See draw_networkx() for more
    full-featured drawing that allows title, axis labels etc.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    ax : Matplotlib Axes object, optional
        Draw the graph in specified Matplotlib axes. If not given, the
        current axes of the current figure are used; if that figure has no
        axes yet, a new axes spanning the whole figure is added.

    kwds : optional keywords
        See networkx.draw_networkx() for a description of optional keywords.
        Unless ``with_labels`` is given explicitly, node labels are drawn
        only when a ``labels`` keyword is supplied.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    See Also
    --------
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels

    Notes
    -----
    This function has the same name as pylab.draw and pyplot.draw
    so beware when using `from networkx import *`
    since you might overwrite the pylab.draw function.

    With pyplot use

    >>> import matplotlib.pyplot as plt
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)  # networkx draw()
    >>> plt.draw()  # pyplot draw()

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html
    """
    import matplotlib.pyplot as plt

    if ax is None:
        cf = plt.gcf()
    else:
        cf = ax.get_figure()
    cf.set_facecolor("w")
    if ax is None:
        # Reuse the figure's current axes if it has any, otherwise add one
        # covering the full figure. ``Figure.axes`` is public API; the private
        # ``Figure._axstack`` used previously has changed between Matplotlib
        # releases and is no longer callable in newer versions.
        if cf.axes:
            ax = cf.gca()
        else:
            ax = cf.add_axes((0, 0, 1, 1))

    # Default: show node labels only if the caller supplied them.
    if "with_labels" not in kwds:
        kwds["with_labels"] = "labels" in kwds

    draw_networkx(G, pos=pos, ax=ax, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()
125
126
def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
    r"""Draw the graph G using Matplotlib.

    Draw the graph with Matplotlib with options for node positions,
    labeling, titles, and many other drawing features.
    See draw() for simple drawing without labels or axes.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).
        Note: Arrows will be the same color as edges.

    arrowstyle : str (default='-\|>')
        For directed graphs, choose the style of the arrowheads.
        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. See `matplotlib.patches.FancyArrowPatch` for attribute
        `mutation_scale` for more info.

    connectionstyle : string (default="arc3")
        Connection style of edges drawn as FancyArrowPatches, e.g.
        'arc3,rad=0.2' for curved edges.
        See `matplotlib.patches.ConnectionStyle` for more options.

    min_source_margin, min_target_margin : int (default=0)
        Minimum gap between an edge and its source / target node.

    with_labels :  bool (default=True)
        Set to True to draw labels on the nodes.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default=list(G))
        Draw only specified nodes

    edgelist : list (default=list(G.edges()))
        Draw only specified edges

    node_size : scalar or array (default=300)
        Size of nodes.  If an array is specified it must be the
        same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape :  string (default='o')
        The shape of the node.  Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or None (default=None)
        The node and edge transparency

    cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of nodes

    vmin,vmax : float, optional
        Minimum and maximum for node colormap scaling

    linewidths : scalar or sequence (default=None)
        Line width of symbol border. None uses the Matplotlib default.

    edgecolors : color or sequence of colors, optional
        Colors of node borders

    margins : float or 2-tuple, optional
        Padding for axis autoscaling, forwarded to draw_networkx_nodes().

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    labels : dictionary (default=None)
        Node labels in a dictionary of text labels keyed by node

    font_size : int (default=12 for nodes, 10 for edges)
        Font size for text labels

    font_color : string (default='k' black)
        Font color string

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    bbox : Matplotlib bbox, optional
        Text box properties for node labels.

    horizontalalignment, verticalalignment : string, optional
        Alignment of node labels relative to node positions.

    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries.

    label : string, optional
        Label for graph legend

    kwds : optional keywords
        See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
        networkx.draw_networkx_labels() for a description of optional keywords.

    Raises
    ------
    ValueError
        If any keyword in `kwds` is not accepted by one of the three
        drawing functions above.

    Notes
    -----
    For directed graphs, arrows  are drawn at the head end.  Arrows can be
    turned off with keyword arrows=False.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw_networkx(G)
    >>> nx.draw_networkx(G, pos=nx.spring_layout(G))  # use spring layout

    >>> import matplotlib.pyplot as plt
    >>> limits = plt.axis("off")  # turn off axis

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    import matplotlib.pyplot as plt

    # Keywords understood by each of the three drawing functions; a keyword
    # may belong to several groups (e.g. ``ax``, ``alpha``) and is then
    # forwarded to each of them.
    valid_node_kwds = (
        "nodelist",
        "node_size",
        "node_color",
        "node_shape",
        "alpha",
        "cmap",
        "vmin",
        "vmax",
        "ax",
        "linewidths",
        "edgecolors",
        "label",
        "margins",
    )

    valid_edge_kwds = (
        "edgelist",
        "width",
        "edge_color",
        "style",
        "alpha",
        "arrowstyle",
        "arrowsize",
        "edge_cmap",
        "edge_vmin",
        "edge_vmax",
        "ax",
        "label",
        "node_size",
        "nodelist",
        "node_shape",
        "connectionstyle",
        "min_source_margin",
        "min_target_margin",
    )

    valid_label_kwds = (
        "labels",
        "font_size",
        "font_color",
        "font_family",
        "font_weight",
        "alpha",
        "bbox",
        "ax",
        "horizontalalignment",
        "verticalalignment",
        "clip_on",
    )

    valid_kwds = {*valid_node_kwds, *valid_edge_kwds, *valid_label_kwds}

    # Report every unknown keyword at once, in the order it was given.
    invalid = [k for k in kwds if k not in valid_kwds]
    if invalid:
        invalid_args = ", ".join(invalid)
        raise ValueError(f"Received invalid argument(s): {invalid_args}")

    node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds}
    edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds}
    label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds}

    if pos is None:
        pos = nx.drawing.spring_layout(G)  # default to spring layout

    draw_networkx_nodes(G, pos, **node_kwds)
    draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds)
    if with_labels:
        draw_networkx_labels(G, pos, **label_kwds)
    plt.draw_if_interactive()
338
339
def draw_networkx_nodes(
    G,
    pos,
    nodelist=None,
    node_size=300,
    node_color="#1f78b4",
    node_shape="o",
    alpha=None,
    cmap=None,
    vmin=None,
    vmax=None,
    ax=None,
    linewidths=None,
    edgecolors=None,
    label=None,
    margins=None,
):
    """Draw the nodes of the graph G.

    This draws only the nodes of the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default list(G))
        Draw only specified nodes. Any iterable of nodes is accepted.

    node_size : scalar or array (default=300)
        Size of nodes.  If an array it must be the same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape :  string (default='o')
        The shape of the node.  Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or array of floats (default=None)
        The node transparency.  This can be a single alpha value,
        in which case it will be applied to all the nodes. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    cmap : Matplotlib colormap (default=None)
        Colormap for mapping intensities of nodes

    vmin,vmax : floats or None (default=None)
        Minimum and maximum for node colormap scaling

    linewidths : [None | scalar | sequence] (default=None)
        Line width of symbol border. None uses the Matplotlib default.

    edgecolors : [None | scalar | sequence] (default=None)
        Colors of node borders. None uses the Matplotlib default.

    label : [None | string]
        Label for legend

    margins : float or 2-tuple, optional
        Sets the padding for axis autoscaling. Increase margin to prevent
        clipping for nodes that are near the edges of an image. Values should
        be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
        for details. The default is `None`, which uses the Matplotlib default.

    Returns
    -------
    matplotlib.collections.PathCollection
        `PathCollection` of the nodes. An empty `PathCollection` is returned
        (and nothing is drawn) when `nodelist` is empty.

    Raises
    ------
    NetworkXError
        If a node in `nodelist` has no entry in `pos`.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from collections.abc import Iterable
    import numpy as np
    import matplotlib as mpl
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    # Materialize so len() works and the nodes can be traversed more than
    # once (for positions here and again in apply_alpha) even when the
    # caller passes a generator or other one-shot iterable.
    nodelist = list(G) if nodelist is None else list(nodelist)

    if len(nodelist) == 0:  # empty nodelist, no drawing
        return mpl.collections.PathCollection(None)

    try:
        xy = np.asarray([pos[v] for v in nodelist])
    except KeyError as e:
        raise nx.NetworkXError(f"Node {e} has no position.") from e

    # Per-node alpha values are folded into the colors because scatter()
    # only accepts a scalar alpha here.
    if isinstance(alpha, Iterable):
        node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
        alpha = None

    node_collection = ax.scatter(
        xy[:, 0],
        xy[:, 1],
        s=node_size,
        c=node_color,
        marker=node_shape,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        alpha=alpha,
        linewidths=linewidths,
        edgecolors=edgecolors,
        label=label,
    )
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    if margins is not None:
        if isinstance(margins, Iterable):
            ax.margins(*margins)
        else:
            ax.margins(margins)

    node_collection.set_zorder(2)  # nodes are drawn on top of edges
    return node_collection
493
494
def draw_networkx_edges(
    G,
    pos,
    edgelist=None,
    width=1.0,
    edge_color="k",
    style="solid",
    alpha=None,
    arrowstyle="-|>",
    arrowsize=10,
    edge_cmap=None,
    edge_vmin=None,
    edge_vmax=None,
    ax=None,
    arrows=None,
    label=None,
    node_size=300,
    nodelist=None,
    node_shape="o",
    connectionstyle="arc3",
    min_source_margin=0,
    min_target_margin=0,
):
    r"""Draw the edges of the graph G.

    This draws only the edges of the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edgelist : collection of edge tuples (default=G.edges())
        Draw only specified edges. Any iterable of edge tuples is accepted;
        only the first two items of each tuple (source, target) are used.

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    alpha : float or None (default=None)
        The edge transparency

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).

        Note: Arrowheads will be the same color as edges.

    arrowstyle : str (default='-\|>')
        Style of the arrowheads when edges are drawn as FancyArrowPatches.
        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. See `matplotlib.patches.FancyArrowPatch` for attribute
        `mutation_scale` for more info.

    connectionstyle : string (default="arc3")
        Pass the connectionstyle parameter to create curved arc of rounding
        radius rad. For example, connectionstyle='arc3,rad=0.2'.
        See `matplotlib.patches.ConnectionStyle` and
        `matplotlib.patches.FancyArrowPatch` for more info.

    node_size : scalar or array (default=300)
        Size of nodes. Though the nodes are not drawn with this function, the
        node size is used in determining edge positioning.

    nodelist : list, optional (default=G.nodes())
       This provides the node order for the `node_size` array (if it is an array).

    node_shape :  string (default='o')
        The marker used for nodes, used in determining edge positioning.
        Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'.

    label : None or string
        Label for legend

    min_source_margin : int (default=0)
        The minimum margin (gap) at the beginning of the edge at the source.

    min_target_margin : int (default=0)
        The minimum margin (gap) at the end of the edge at the target.

    Returns
    -------
    matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch
        If ``arrows=True``, a list of FancyArrowPatches is returned.
        If ``arrows=False``, a LineCollection is returned.
        If ``arrows=None`` (the default), then a LineCollection is returned if
        `G` is undirected, otherwise returns a list of FancyArrowPatches.
        An empty list is returned when there are no edges to draw.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end.  Arrows can be
    turned off with keyword arrows=False or by passing an arrowstyle without
    an arrow on the end.

    Be sure to include `node_size` as a keyword argument; arrows are
    drawn considering the size of nodes.

    Self-loops in `edgelist` are always drawn with
    `~matplotlib.patches.FancyArrowPatch` regardless of the value of `arrows`
    or whether `G` is directed. Self-loops of `G` that are not in `edgelist`
    are not drawn. When ``arrows=False`` or ``arrows=None`` and `G` is
    undirected, the FancyArrowPatches corresponding to the self-loops are not
    explicitly returned. They should instead be accessed via the
    ``Axes.patches`` attribute (see examples).

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))

    >>> G = nx.DiGraph()
    >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
    >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
    >>> alphas = [0.3, 0.4, 0.5]
    >>> for i, arc in enumerate(arcs):  # change alpha values of arcs
    ...     arc.set_alpha(alphas[i])

    The FancyArrowPatches corresponding to self-loops are not always
    returned, but can always be accessed via the ``patches`` attribute of the
    `matplotlib.Axes` object.

    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> G = nx.Graph([(0, 1), (0, 0)])  # Self-loop at node 0
    >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax)
    >>> self_loop_fap = ax.patches[0]

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_labels
    draw_networkx_edge_labels

    """
    import numpy as np
    import matplotlib as mpl
    import matplotlib.colors  # call as mpl.colors
    import matplotlib.patches  # call as mpl.patches
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.path  # call as mpl.path
    import matplotlib.pyplot as plt

    # The default behavior is to use LineCollection to draw edges for
    # undirected graphs (for performance reasons) and use FancyArrowPatches
    # for directed graphs.
    # The `arrows` keyword can be used to override the default behavior
    use_linecollection = not G.is_directed()
    if arrows in (True, False):
        use_linecollection = not arrows

    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = list(G.edges())
    else:
        # Materialize so len() and positional indexing work for any iterable
        # (generators, edge views, sets, ...).
        edgelist = list(edgelist)

    if len(edgelist) == 0:  # no edges!
        return []

    if nodelist is None:
        nodelist = list(G.nodes())

    # FancyArrowPatch handles color=None different from LineCollection
    if edge_color is None:
        edge_color = "k"

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
    n_edges = len(edge_pos)

    # Check if edge_color is an array of floats and map to edge_cmap.
    # This is the only case handled differently from matplotlib.
    # (Builtin all() replaces np.alltrue, which was removed in NumPy 2.0.)
    if (
        np.iterable(edge_color)
        and (len(edge_color) == n_edges)
        and all(isinstance(c, Number) for c in edge_color)
    ):
        if edge_cmap is not None:
            assert isinstance(edge_cmap, mpl.colors.Colormap)
        else:
            edge_cmap = plt.get_cmap()
        if edge_vmin is None:
            edge_vmin = min(edge_color)
        if edge_vmax is None:
            edge_vmax = max(edge_color)
        color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
        edge_color = [edge_cmap(color_normal(e)) for e in edge_color]

    # Initial view extent of all edges; ``h`` also scales self-loops below.
    minx = np.amin(np.ravel(edge_pos[:, :, 0]))
    maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
    w = maxx - minx
    h = maxy - miny

    def _draw_networkx_edges_line_collection():
        # Draw every edge as one straight segment of a single LineCollection.
        # No offsets are used, so no offset transform is passed (the former
        # ``transOffset`` keyword is deprecated in Matplotlib).
        edge_collection = mpl.collections.LineCollection(
            edge_pos,
            colors=edge_color,
            linewidths=width,
            antialiaseds=(1,),
            linestyle=style,
            alpha=alpha,
        )
        edge_collection.set_cmap(edge_cmap)
        edge_collection.set_clim(edge_vmin, edge_vmax)
        edge_collection.set_zorder(1)  # edges go behind nodes
        edge_collection.set_label(label)
        ax.add_collection(edge_collection)

        return edge_collection

    def _draw_networkx_edges_fancy_arrow_patch(edge_indices, arrow_style):
        # Draw the edges at ``edge_indices`` (positions in ``edgelist`` /
        # ``edge_pos``) as FancyArrowPatches. Indexing into the full edge
        # list keeps per-edge colors, widths and node sizes aligned even
        # when only a subset (the self-loops) is drawn this way.
        #
        # Note: Waiting for someone to implement arrow to intersection with
        # marker.  Meanwhile, this works well for polygons with more than 4
        # sides and circle.

        def to_marker_edge(marker_size, marker):
            if marker in "s^>v<d":  # `large` markers need extra space
                return np.sqrt(2 * marker_size) / 2
            else:
                return np.sqrt(marker_size) / 2

        # Draw arrows with `matplotlib.patches.FancyarrowPatch`
        arrow_collection = []
        mutation_scale = arrowsize  # scale factor of arrow head

        base_connection_style = mpl.patches.ConnectionStyle(connectionstyle)

        # Fallback for self-loop scale. Left outside of _connectionstyle so it is
        # only computed once
        max_nodesize = np.array(node_size).max()

        def _connectionstyle(posA, posB, *args, **kwargs):
            # check if we need to do a self-loop
            if np.all(posA == posB):
                # Self-loops are scaled by view extent, except in cases the extent
                # is 0, e.g. for a single node. In this case, fall back to scaling
                # by the maximum node size
                selfloop_ht = 0.005 * max_nodesize if h == 0 else h
                # this is called with _screen space_ values so covert back
                # to data space
                data_loc = ax.transData.inverted().transform(posA)
                v_shift = 0.1 * selfloop_ht
                h_shift = v_shift * 0.5
                # put the top of the loop first so arrow is not hidden by node
                path = [
                    # 1
                    data_loc + np.asarray([0, v_shift]),
                    # 4 4 4
                    data_loc + np.asarray([h_shift, v_shift]),
                    data_loc + np.asarray([h_shift, 0]),
                    data_loc,
                    # 4 4 4
                    data_loc + np.asarray([-h_shift, 0]),
                    data_loc + np.asarray([-h_shift, v_shift]),
                    data_loc + np.asarray([0, v_shift]),
                ]

                ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])
            # if not, fall back to the user specified behavior
            else:
                ret = base_connection_style(posA, posB, *args, **kwargs)

            return ret

        # FancyArrowPatch doesn't handle color strings
        arrow_colors = mpl.colors.to_rgba_array(edge_color, alpha)
        for i in edge_indices:
            (x1, y1), (x2, y2) = edge_pos[i]
            if np.iterable(node_size):  # many node sizes
                source, target = edgelist[i][:2]
                source_node_size = node_size[nodelist.index(source)]
                target_node_size = node_size[nodelist.index(target)]
                shrink_source = to_marker_edge(source_node_size, node_shape)
                shrink_target = to_marker_edge(target_node_size, node_shape)
            else:
                shrink_source = shrink_target = to_marker_edge(node_size, node_shape)

            if shrink_source < min_source_margin:
                shrink_source = min_source_margin

            if shrink_target < min_target_margin:
                shrink_target = min_target_margin

            # One color/width per edge, a single shared value, or a shorter
            # sequence that is cycled through; modulo indexing covers all three.
            arrow_color = arrow_colors[i % len(arrow_colors)]
            if np.iterable(width):
                line_width = width[i % len(width)]
            else:
                line_width = width

            arrow = mpl.patches.FancyArrowPatch(
                (x1, y1),
                (x2, y2),
                arrowstyle=arrow_style,
                shrinkA=shrink_source,
                shrinkB=shrink_target,
                mutation_scale=mutation_scale,
                color=arrow_color,
                linewidth=line_width,
                connectionstyle=_connectionstyle,
                linestyle=style,
                zorder=1,
            )  # arrows go behind nodes

            arrow_collection.append(arrow)
            ax.add_patch(arrow)

        return arrow_collection

    # Draw the edges
    if use_linecollection:
        edge_viz_obj = _draw_networkx_edges_line_collection()
        # A self-loop is a zero-length segment in the LineCollection, so the
        # self-loops that are part of ``edgelist`` are drawn again as
        # (arrowless) FancyArrowPatches to make them visible.
        selfloop_indices = [i for i, e in enumerate(edgelist) if e[0] == e[1]]
        if selfloop_indices:
            _draw_networkx_edges_fancy_arrow_patch(selfloop_indices, "-")
    else:
        edge_viz_obj = _draw_networkx_edges_fancy_arrow_patch(
            range(n_edges), arrowstyle
        )

    # update view after drawing
    padx, pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return edge_viz_obj
880
881
def draw_networkx_labels(
    G,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
):
    """Draw node labels on the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    labels : dictionary (default={n: n for n in G})
        Node labels in a dictionary of text labels keyed by node.
        Node-keys in labels should appear as keys in `pos`.
        If needed use: `{n:lab for n,lab in labels.items() if n in pos}`

    font_size : int (default=12)
        Font size for text labels

    font_color : string (default='k' black)
        Font color string

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    alpha : float or None (default=None)
        The text transparency

    bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
        Specify text box properties (e.g. shape, color etc.) for node labels.

    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries

    Returns
    -------
    dict
        `dict` of labels keyed on the nodes

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_edge_labels
    """
    import matplotlib.pyplot as plt

    axes = ax if ax is not None else plt.gca()

    if labels is None:
        labels = {node: node for node in G.nodes()}

    # Text properties shared by every label; placement is in data coordinates.
    text_kwargs = dict(
        size=font_size,
        color=font_color,
        family=font_family,
        weight=font_weight,
        alpha=alpha,
        horizontalalignment=horizontalalignment,
        verticalalignment=verticalalignment,
        transform=axes.transData,
        bbox=bbox,
        clip_on=clip_on,
    )

    # Matplotlib has no text collection, so a dict of Text artists stands in.
    drawn = {}
    for node, raw_label in labels.items():
        x, y = pos[node]
        # Stringify so that "1" and 1 render identically.
        text = raw_label if isinstance(raw_label, str) else str(raw_label)
        drawn[node] = axes.text(x, y, text, **text_kwargs)

    axes.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return drawn
1004
1005
def draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
):
    """Draw edge labels.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_labels : dictionary or None (default=None)
        Edge labels in a dictionary of labels keyed by edge two-tuple.
        Only labels for the keys in the dictionary are drawn.
        If None, every edge of G is labeled with the string form of its
        edge data dictionary.

    label_pos : float (default=0.5)
        Position of edge label along edge (0=head, 0.5=center, 1=tail)

    font_size : int (default=10)
        Font size for text labels

    font_color : string (default='k' black)
        Font color string

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    alpha : float or None (default=None)
        The text transparency

    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.

    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    rotate : bool (default=True)
        Rotate edge labels to lie parallel to edges

    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    Returns
    -------
    dict
        `dict` of labels keyed by edge

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()
    if edge_labels is None:
        labels = {(u, v): d for u, v, d in G.edges(data=True)}
    else:
        labels = edge_labels
    # Use a default box of white with white border.  This is loop-invariant,
    # so resolve it once rather than on every edge.
    if bbox is None:
        bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))
    text_items = {}
    for (n1, n2), label in labels.items():
        (x1, y1) = pos[n1]
        (x2, y2) = pos[n2]
        # Linear interpolation: label_pos=1 puts the label at n1, 0 at n2.
        (x, y) = (
            x1 * label_pos + x2 * (1.0 - label_pos),
            y1 * label_pos + y2 * (1.0 - label_pos),
        )

        if rotate:
            # in degrees
            angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
            # make label orientation "right-side-up"
            if angle > 90:
                angle -= 180
            if angle < -90:
                angle += 180
            # transform data coordinate angle to screen coordinate angle
            xy = np.array((x, y))
            trans_angle = ax.transData.transform_angles(
                np.array((angle,)), xy.reshape((1, 2))
            )[0]
        else:
            trans_angle = 0.0
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        t = ax.text(
            x,
            y,
            label,
            size=font_size,
            color=font_color,
            family=font_family,
            weight=font_weight,
            alpha=alpha,
            horizontalalignment=horizontalalignment,
            verticalalignment=verticalalignment,
            rotation=trans_angle,
            transform=ax.transData,
            bbox=bbox,
            zorder=1,
            clip_on=clip_on,
        )
        text_items[(n1, n2)] = t

    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return text_items
1164
1165
def draw_circular(G, **kwargs):
    """Draw the graph G with its nodes placed on a circle.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to draw(); see networkx.draw_networkx() for the accepted
        keywords. The pos keyword is not used, since positions come from
        circular_layout().
    """
    positions = circular_layout(G)
    draw(G, positions, **kwargs)
1180
1181
def draw_kamada_kawai(G, **kwargs):
    """Draw the graph G using a Kamada-Kawai force-directed placement.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to draw(); see networkx.draw_networkx() for the accepted
        keywords. The pos keyword is not used, since positions come from
        kamada_kawai_layout().
    """
    positions = kamada_kawai_layout(G)
    draw(G, positions, **kwargs)
1196
1197
def draw_random(G, **kwargs):
    """Draw the graph G with randomly placed nodes.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to draw(); see networkx.draw_networkx() for the accepted
        keywords. The pos keyword is not used, since positions come from
        random_layout().
    """
    positions = random_layout(G)
    draw(G, positions, **kwargs)
1212
1213
def draw_spectral(G, **kwargs):
    """Draw the graph G with a spectral 2D layout.

    The layout is built from the unnormalized Laplacian and tends to reveal
    clusters of nodes approximating the ratio cut. Node coordinates are the
    entries of the second and third eigenvectors, taken in order of
    ascending eigenvalue starting from the second one.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to draw(); see networkx.draw_networkx() for the accepted
        keywords. The pos keyword is not used, since positions come from
        spectral_layout().
    """
    positions = spectral_layout(G)
    draw(G, positions, **kwargs)
1233
1234
def draw_spring(G, **kwargs):
    """Draw the graph G using a spring (force-directed) placement.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to draw(); see networkx.draw_networkx() for the accepted
        keywords. The pos keyword is not used, since positions come from
        spring_layout().
    """
    positions = spring_layout(G)
    draw(G, positions, **kwargs)
1249
1250
def draw_shell(G, **kwargs):
    """Draw networkx graph with shell layout.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        See networkx.draw_networkx() for a description of optional keywords,
        with the exception of the pos parameter which is not used by this
        function. The extra keyword ``nlist`` is consumed here and passed to
        shell_layout(); it is never forwarded to draw().
    """
    # Pop unconditionally: ``nlist`` is a shell_layout option, not a drawing
    # keyword, so it must not reach draw() even when passed as nlist=None.
    nlist = kwargs.pop("nlist", None)
    draw(G, shell_layout(G, nlist=nlist), **kwargs)
1268
1269
def draw_planar(G, **kwargs):
    """Draw a planar networkx graph using a planar embedding layout.

    Parameters
    ----------
    G : graph
        A planar networkx graph

    kwargs : optional keywords
        Forwarded to draw(); see networkx.draw_networkx() for the accepted
        keywords. The pos keyword is not used, since positions come from
        planar_layout().
    """
    positions = planar_layout(G)
    draw(G, positions, **kwargs)
1284
1285
def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
    """Apply an alpha (or list of alphas) to the colors provided.

    Parameters
    ----------

    colors : color string or array of floats (default='r')
        Color of element. Can be a single color format string,
        or a sequence of colors with the same length as nodelist.
        If numeric values are specified they will be mapped to
        colors using the cmap and vmin,vmax parameters.  See
        matplotlib.scatter for more details.

    alpha : float or array of floats
        Alpha values for elements. This can be a single alpha value, in
        which case it will be applied to all the elements of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    elem_list : array of networkx objects
        The list of elements which are being colored. These could be nodes,
        edges or labels.

    cmap : matplotlib colormap
        Color map for use if colors is a list of floats corresponding to points
        on a color mapping.

    vmin, vmax : float
        Minimum and maximum values for normalizing colors if a colormap is used

    Returns
    -------

    rgba_colors : numpy ndarray
        Array of shape (k, 4) containing RGBA values. When the array has to
        be expanded to the number of elements, the provided colors are
        repeated cyclically (just as alpha values are).

    """
    from itertools import islice, cycle
    import numpy as np
    import matplotlib as mpl
    import matplotlib.colors  # call as mpl.colors
    import matplotlib.cm  # call as mpl.cm

    n_elem = len(elem_list)

    # If we have been provided with a non-empty list of numbers as long as
    # elem_list, apply the color mapping.  The emptiness check guards the
    # colors[0] lookup.
    if len(colors) == n_elem and n_elem > 0 and isinstance(colors[0], Number):
        mapper = mpl.cm.ScalarMappable(cmap=cmap)
        mapper.set_clim(vmin, vmax)
        rgba_colors = mapper.to_rgba(colors)
    # Otherwise, convert colors to RGBA with matplotlib.colors.to_rgba.  The
    # result is a 2-D ndarray to be consistent with ScalarMappable.to_rgba.
    else:
        try:
            # A single color specification.
            rgba_colors = np.array([mpl.colors.to_rgba(colors)])
        except ValueError:
            # A sequence of color specifications; reshape keeps an empty
            # sequence 2-D (shape (0, 4)) so the column indexing below works.
            rgba_colors = np.array(
                [mpl.colors.to_rgba(color) for color in colors], dtype=float
            ).reshape(-1, 4)

    # Set the final column of the rgba_colors to have the relevant alpha values
    try:
        n_alpha = len(alpha)
    except TypeError:
        # Scalar alpha: apply it uniformly to every color.
        rgba_colors[:, -1] = alpha
        return rgba_colors

    # If alpha is longer than the number of colors, resize to the number of
    # elements.  Also, if rgba_colors.size (the number of elements of
    # rgba_colors) is the same as the number of elements, resize the array,
    # to avoid it being interpreted as a colormap by scatter().  np.resize
    # repeats the rows cyclically, so a single color is broadcast and
    # multiple distinct colors are preserved (previously every row was
    # overwritten with the first color, discarding the others).
    if n_alpha > len(rgba_colors) or rgba_colors.size == n_elem:
        rgba_colors = np.resize(rgba_colors, (n_elem, 4))
    rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
    return rgba_colors
1360