1import plotly.graph_objs as go
2import plotly.io as pio
3from collections import namedtuple, OrderedDict
4from ._special_inputs import IdentityMap, Constant, Range
5
6from _plotly_utils.basevalidators import ColorscaleValidator
7from plotly.colors import qualitative, sequential
8import math
9import pandas as pd
10import numpy as np
11
12from plotly.subplots import (
13    make_subplots,
14    _set_trace_grid_reference,
15    _subplot_type_for_trace_type,
16)
17
18NO_COLOR = "px_no_color_constant"
19
20# Declare all supported attributes, across all plot types
21direct_attrables = (
22    ["base", "x", "y", "z", "a", "b", "c", "r", "theta", "size", "x_start", "x_end"]
23    + ["hover_name", "text", "names", "values", "parents", "wide_cross"]
24    + ["ids", "error_x", "error_x_minus", "error_y", "error_y_minus", "error_z"]
25    + ["error_z_minus", "lat", "lon", "locations", "animation_group"]
26)
27array_attrables = ["dimensions", "custom_data", "hover_data", "path", "wide_variable"]
28group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
29renameable_group_attrables = [
30    "color",  # renamed to marker.color or line.color in infer_config
31    "symbol",  # renamed to marker.symbol in infer_config
32    "line_dash",  # renamed to line.dash in infer_config
33]
34all_attrables = (
35    direct_attrables + array_attrables + group_attrables + renameable_group_attrables
36)
37
38cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin]
39cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour]
40
41
42class PxDefaults(object):
43    __slots__ = [
44        "template",
45        "width",
46        "height",
47        "color_discrete_sequence",
48        "color_discrete_map",
49        "color_continuous_scale",
50        "symbol_sequence",
51        "symbol_map",
52        "line_dash_sequence",
53        "line_dash_map",
54        "size_max",
55        "category_orders",
56        "labels",
57    ]
58
59    def __init__(self):
60        self.reset()
61
62    def reset(self):
63        self.template = None
64        self.width = None
65        self.height = None
66        self.color_discrete_sequence = None
67        self.color_discrete_map = {}
68        self.color_continuous_scale = None
69        self.symbol_sequence = None
70        self.symbol_map = {}
71        self.line_dash_sequence = None
72        self.line_dash_map = {}
73        self.size_max = 20
74        self.category_orders = {}
75        self.labels = {}
76
77
78defaults = PxDefaults()
79del PxDefaults
80
81
82MAPBOX_TOKEN = None
83
84
85def set_mapbox_access_token(token):
86    """
87    Arguments:
88        token: A Mapbox token to be used in `plotly.express.scatter_mapbox` and \
89        `plotly.express.line_mapbox` figures. See \
90        https://docs.mapbox.com/help/how-mapbox-works/access-tokens/ for more details
91    """
92    global MAPBOX_TOKEN
93    MAPBOX_TOKEN = token
94
95
96def get_trendline_results(fig):
97    """
98    Extracts fit statistics for trendlines (when applied to figures generated with
99    the `trendline` argument set to `"ols"`).
100
101    Arguments:
102        fig: the output of a `plotly.express` charting call
103    Returns:
104        A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels`
105        results objects, along with columns identifying the subset of the data the
106        trendline was fit on.
107    """
108    return fig._px_trendlines
109
110
111Mapping = namedtuple(
112    "Mapping",
113    [
114        "show_in_trace_name",
115        "grouper",
116        "val_map",
117        "sequence",
118        "updater",
119        "variable",
120        "facet",
121    ],
122)
123TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"])
124
125
126def get_label(args, column):
127    try:
128        return args["labels"][column]
129    except Exception:
130        return column
131
132
133def invert_label(args, column):
134    """Invert mapping.
135    Find key corresponding to value column in dict args["labels"].
136    Returns `column` if the value does not exist.
137    """
138    reversed_labels = {value: key for (key, value) in args["labels"].items()}
139    try:
140        return reversed_labels[column]
141    except Exception:
142        return column
143
144
145def _is_continuous(df, col_name):
146    return df[col_name].dtype.kind in "ifc"
147
148
149def get_decorated_label(args, column, role):
150    original_label = label = get_label(args, column)
151    if "histfunc" in args and (
152        (role == "z")
153        or (role == "x" and "orientation" in args and args["orientation"] == "h")
154        or (role == "y" and "orientation" in args and args["orientation"] == "v")
155    ):
156        histfunc = args["histfunc"] or "count"
157        if histfunc != "count":
158            label = "%s of %s" % (histfunc, label)
159        else:
160            label = "count"
161
162        if "histnorm" in args and args["histnorm"] is not None:
163            if label == "count":
164                label = args["histnorm"]
165            else:
166                histnorm = args["histnorm"]
167                if histfunc == "sum":
168                    if histnorm == "probability":
169                        label = "%s of %s" % ("fraction", label)
170                    elif histnorm == "percent":
171                        label = "%s of %s" % (histnorm, label)
172                    else:
173                        label = "%s weighted by %s" % (histnorm, original_label)
174                elif histnorm == "probability":
175                    label = "%s of sum of %s" % ("fraction", label)
176                elif histnorm == "percent":
177                    label = "%s of sum of %s" % ("percent", label)
178                else:
179                    label = "%s of %s" % (histnorm, label)
180
181        if "barnorm" in args and args["barnorm"] is not None:
182            label = "%s (normalized as %s)" % (label, args["barnorm"])
183
184    return label
185
186
187def make_mapping(args, variable):
188    if variable == "line_group" or variable == "animation_frame":
189        return Mapping(
190            show_in_trace_name=False,
191            grouper=args[variable],
192            val_map={},
193            sequence=[""],
194            variable=variable,
195            updater=(lambda trace, v: v),
196            facet=None,
197        )
198    if variable == "facet_row" or variable == "facet_col":
199        letter = "x" if variable == "facet_col" else "y"
200        return Mapping(
201            show_in_trace_name=False,
202            variable=letter,
203            grouper=args[variable],
204            val_map={},
205            sequence=[i for i in range(1, 1000)],
206            updater=(lambda trace, v: v),
207            facet="row" if variable == "facet_row" else "col",
208        )
209    (parent, variable) = variable.split(".")
210    vprefix = variable
211    arg_name = variable
212    if variable == "color":
213        vprefix = "color_discrete"
214    if variable == "dash":
215        arg_name = "line_dash"
216        vprefix = "line_dash"
217    if args[vprefix + "_map"] == "identity":
218        val_map = IdentityMap()
219    else:
220        val_map = args[vprefix + "_map"].copy()
221    return Mapping(
222        show_in_trace_name=True,
223        variable=variable,
224        grouper=args[arg_name],
225        val_map=val_map,
226        sequence=args[vprefix + "_sequence"],
227        updater=lambda trace, v: trace.update({parent: {variable: v}}),
228        facet=None,
229    )
230
231
232def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
233    """Populates a dict with arguments to update trace
234
235    Parameters
236    ----------
237    args : dict
238        args to be used for the trace
239    trace_spec : NamedTuple
240        which kind of trace to be used (has constructor, marginal etc.
241        attributes)
242    trace_data : pandas DataFrame
243        data
244    mapping_labels : dict
245        to be used for hovertemplate
246    sizeref : float
247        marker sizeref
248
249    Returns
250    -------
251    trace_patch : dict
252        dict to be used to update trace
253    fit_results : dict
254        fit information to be used for trendlines
255    """
256    if "line_close" in args and args["line_close"]:
257        trace_data = trace_data.append(trace_data.iloc[0])
258    trace_patch = trace_spec.trace_patch.copy() or {}
259    fit_results = None
260    hover_header = ""
261    for attr_name in trace_spec.attrs:
262        attr_value = args[attr_name]
263        attr_label = get_decorated_label(args, attr_value, attr_name)
264        if attr_name == "dimensions":
265            dims = [
266                (name, column)
267                for (name, column) in trace_data.iteritems()
268                if ((not attr_value) or (name in attr_value))
269                and (
270                    trace_spec.constructor != go.Parcoords
271                    or _is_continuous(args["data_frame"], name)
272                )
273                and (
274                    trace_spec.constructor != go.Parcats
275                    or (attr_value is not None and name in attr_value)
276                    or len(args["data_frame"][name].unique())
277                    <= args["dimensions_max_cardinality"]
278                )
279            ]
280            trace_patch["dimensions"] = [
281                dict(label=get_label(args, name), values=column)
282                for (name, column) in dims
283            ]
284            if trace_spec.constructor == go.Splom:
285                for d in trace_patch["dimensions"]:
286                    d["axis"] = dict(matches=True)
287                mapping_labels["%{xaxis.title.text}"] = "%{x}"
288                mapping_labels["%{yaxis.title.text}"] = "%{y}"
289
290        elif attr_value is not None:
291            if attr_name == "size":
292                if "marker" not in trace_patch:
293                    trace_patch["marker"] = dict()
294                trace_patch["marker"]["size"] = trace_data[attr_value]
295                trace_patch["marker"]["sizemode"] = "area"
296                trace_patch["marker"]["sizeref"] = sizeref
297                mapping_labels[attr_label] = "%{marker.size}"
298            elif attr_name == "marginal_x":
299                if trace_spec.constructor == go.Histogram:
300                    mapping_labels["count"] = "%{y}"
301            elif attr_name == "marginal_y":
302                if trace_spec.constructor == go.Histogram:
303                    mapping_labels["count"] = "%{x}"
304            elif attr_name == "trendline":
305                if (
306                    attr_value in ["ols", "lowess"]
307                    and args["x"]
308                    and args["y"]
309                    and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
310                ):
311                    import statsmodels.api as sm
312
313                    # sorting is bad but trace_specs with "trendline" have no other attrs
314                    sorted_trace_data = trace_data.sort_values(by=args["x"])
315                    y = sorted_trace_data[args["y"]].values
316                    x = sorted_trace_data[args["x"]].values
317
318                    if x.dtype.type == np.datetime64:
319                        x = x.astype(int) / 10 ** 9  # convert to unix epoch seconds
320                    elif x.dtype.type == np.object_:
321                        try:
322                            x = x.astype(np.float64)
323                        except ValueError:
324                            raise ValueError(
325                                "Could not convert value of 'x' ('%s') into a numeric type. "
326                                "If 'x' contains stringified dates, please convert to a datetime column."
327                                % args["x"]
328                            )
329                    if y.dtype.type == np.object_:
330                        try:
331                            y = y.astype(np.float64)
332                        except ValueError:
333                            raise ValueError(
334                                "Could not convert value of 'y' into a numeric type."
335                            )
336
337                    # preserve original values of "x" in case they're dates
338                    trace_patch["x"] = sorted_trace_data[args["x"]][
339                        np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
340                    ]
341
342                    if attr_value == "lowess":
343                        # missing ='drop' is the default value for lowess but not for OLS (None)
344                        # we force it here in case statsmodels change their defaults
345                        trendline = sm.nonparametric.lowess(y, x, missing="drop")
346                        trace_patch["y"] = trendline[:, 1]
347                        hover_header = "<b>LOWESS trendline</b><br><br>"
348                    elif attr_value == "ols":
349                        fit_results = sm.OLS(
350                            y, sm.add_constant(x), missing="drop"
351                        ).fit()
352                        trace_patch["y"] = fit_results.predict()
353                        hover_header = "<b>OLS trendline</b><br>"
354                        if len(fit_results.params) == 2:
355                            hover_header += "%s = %g * %s + %g<br>" % (
356                                args["y"],
357                                fit_results.params[1],
358                                args["x"],
359                                fit_results.params[0],
360                            )
361                        else:
362                            hover_header += "%s = %g<br>" % (
363                                args["y"],
364                                fit_results.params[0],
365                            )
366                        hover_header += (
367                            "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
368                        )
369                    mapping_labels[get_label(args, args["x"])] = "%{x}"
370                    mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
371            elif attr_name.startswith("error"):
372                error_xy = attr_name[:7]
373                arr = "arrayminus" if attr_name.endswith("minus") else "array"
374                if error_xy not in trace_patch:
375                    trace_patch[error_xy] = {}
376                trace_patch[error_xy][arr] = trace_data[attr_value]
377            elif attr_name == "custom_data":
378                if len(attr_value) > 0:
379                    # here we store a data frame in customdata, and it's serialized
380                    # as a list of row lists, which is what we want
381                    trace_patch["customdata"] = trace_data[attr_value]
382            elif attr_name == "hover_name":
383                if trace_spec.constructor not in [
384                    go.Histogram,
385                    go.Histogram2d,
386                    go.Histogram2dContour,
387                ]:
388                    trace_patch["hovertext"] = trace_data[attr_value]
389                    if hover_header == "":
390                        hover_header = "<b>%{hovertext}</b><br><br>"
391            elif attr_name == "hover_data":
392                if trace_spec.constructor not in [
393                    go.Histogram,
394                    go.Histogram2d,
395                    go.Histogram2dContour,
396                ]:
397                    hover_is_dict = isinstance(attr_value, dict)
398                    customdata_cols = args.get("custom_data") or []
399                    for col in attr_value:
400                        if hover_is_dict and not attr_value[col]:
401                            continue
402                        if col in [
403                            args.get("x", None),
404                            args.get("y", None),
405                            args.get("z", None),
406                            args.get("base", None),
407                        ]:
408                            continue
409                        try:
410                            position = args["custom_data"].index(col)
411                        except (ValueError, AttributeError, KeyError):
412                            position = len(customdata_cols)
413                            customdata_cols.append(col)
414                        attr_label_col = get_decorated_label(args, col, None)
415                        mapping_labels[attr_label_col] = "%%{customdata[%d]}" % (
416                            position
417                        )
418
419                    if len(customdata_cols) > 0:
420                        # here we store a data frame in customdata, and it's serialized
421                        # as a list of row lists, which is what we want
422                        trace_patch["customdata"] = trace_data[customdata_cols]
423            elif attr_name == "color":
424                if trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox]:
425                    trace_patch["z"] = trace_data[attr_value]
426                    trace_patch["coloraxis"] = "coloraxis1"
427                    mapping_labels[attr_label] = "%{z}"
428                elif trace_spec.constructor in [
429                    go.Sunburst,
430                    go.Treemap,
431                    go.Pie,
432                    go.Funnelarea,
433                ]:
434                    if "marker" not in trace_patch:
435                        trace_patch["marker"] = dict()
436
437                    if args.get("color_is_continuous"):
438                        trace_patch["marker"]["colors"] = trace_data[attr_value]
439                        trace_patch["marker"]["coloraxis"] = "coloraxis1"
440                        mapping_labels[attr_label] = "%{color}"
441                    else:
442                        trace_patch["marker"]["colors"] = []
443                        if args["color_discrete_map"] is not None:
444                            mapping = args["color_discrete_map"].copy()
445                        else:
446                            mapping = {}
447                        for cat in trace_data[attr_value]:
448                            if mapping.get(cat) is None:
449                                mapping[cat] = args["color_discrete_sequence"][
450                                    len(mapping) % len(args["color_discrete_sequence"])
451                                ]
452                            trace_patch["marker"]["colors"].append(mapping[cat])
453                else:
454                    colorable = "marker"
455                    if trace_spec.constructor in [go.Parcats, go.Parcoords]:
456                        colorable = "line"
457                    if colorable not in trace_patch:
458                        trace_patch[colorable] = dict()
459                    trace_patch[colorable]["color"] = trace_data[attr_value]
460                    trace_patch[colorable]["coloraxis"] = "coloraxis1"
461                    mapping_labels[attr_label] = "%%{%s.color}" % colorable
462            elif attr_name == "animation_group":
463                trace_patch["ids"] = trace_data[attr_value]
464            elif attr_name == "locations":
465                trace_patch[attr_name] = trace_data[attr_value]
466                mapping_labels[attr_label] = "%{location}"
467            elif attr_name == "values":
468                trace_patch[attr_name] = trace_data[attr_value]
469                _label = "value" if attr_label == "values" else attr_label
470                mapping_labels[_label] = "%{value}"
471            elif attr_name == "parents":
472                trace_patch[attr_name] = trace_data[attr_value]
473                _label = "parent" if attr_label == "parents" else attr_label
474                mapping_labels[_label] = "%{parent}"
475            elif attr_name == "ids":
476                trace_patch[attr_name] = trace_data[attr_value]
477                _label = "id" if attr_label == "ids" else attr_label
478                mapping_labels[_label] = "%{id}"
479            elif attr_name == "names":
480                if trace_spec.constructor in [
481                    go.Sunburst,
482                    go.Treemap,
483                    go.Pie,
484                    go.Funnelarea,
485                ]:
486                    trace_patch["labels"] = trace_data[attr_value]
487                    _label = "label" if attr_label == "names" else attr_label
488                    mapping_labels[_label] = "%{label}"
489                else:
490                    trace_patch[attr_name] = trace_data[attr_value]
491            else:
492                trace_patch[attr_name] = trace_data[attr_value]
493                mapping_labels[attr_label] = "%%{%s}" % attr_name
494        elif (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or (
495            trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour]
496            and attr_name == "z"
497        ):
498            # ensure that stuff like "count" gets into the hoverlabel
499            mapping_labels[attr_label] = "%%{%s}" % attr_name
500    if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
501        # Modify mapping_labels according to hover_data keys
502        # if hover_data is a dict
503        mapping_labels_copy = OrderedDict(mapping_labels)
504        if args["hover_data"] and isinstance(args["hover_data"], dict):
505            for k, v in mapping_labels.items():
506                # We need to invert the mapping here
507                k_args = invert_label(args, k)
508                if k_args in args["hover_data"]:
509                    formatter = args["hover_data"][k_args][0]
510                    if formatter:
511                        if isinstance(formatter, str):
512                            mapping_labels_copy[k] = v.replace("}", "%s}" % formatter)
513                    else:
514                        _ = mapping_labels_copy.pop(k)
515        hover_lines = [k + "=" + v for k, v in mapping_labels_copy.items()]
516        trace_patch["hovertemplate"] = hover_header + "<br>".join(hover_lines)
517        trace_patch["hovertemplate"] += "<extra></extra>"
518    return trace_patch, fit_results
519
520
521def configure_axes(args, constructor, fig, orders):
522    configurators = {
523        go.Scatter3d: configure_3d_axes,
524        go.Scatterternary: configure_ternary_axes,
525        go.Scatterpolar: configure_polar_axes,
526        go.Scatterpolargl: configure_polar_axes,
527        go.Barpolar: configure_polar_axes,
528        go.Scattermapbox: configure_mapbox,
529        go.Choroplethmapbox: configure_mapbox,
530        go.Densitymapbox: configure_mapbox,
531        go.Scattergeo: configure_geo,
532        go.Choropleth: configure_geo,
533    }
534    for c in cartesians:
535        configurators[c] = configure_cartesian_axes
536    if constructor in configurators:
537        configurators[constructor](args, fig, orders)
538
539
540def set_cartesian_axis_opts(args, axis, letter, orders):
541    log_key = "log_" + letter
542    range_key = "range_" + letter
543    if log_key in args and args[log_key]:
544        axis["type"] = "log"
545        if range_key in args and args[range_key]:
546            axis["range"] = [math.log(r, 10) for r in args[range_key]]
547    elif range_key in args and args[range_key]:
548        axis["range"] = args[range_key]
549
550    if args[letter] in orders:
551        axis["categoryorder"] = "array"
552        axis["categoryarray"] = (
553            orders[args[letter]]
554            if isinstance(axis, go.layout.XAxis)
555            else list(reversed(orders[args[letter]]))
556        )
557
558
559def configure_cartesian_marginal_axes(args, fig, orders):
560
561    if "histogram" in [args["marginal_x"], args["marginal_y"]]:
562        fig.layout["barmode"] = "overlay"
563
564    nrows = len(fig._grid_ref)
565    ncols = len(fig._grid_ref[0])
566
567    # Set y-axis titles and axis options in the left-most column
568    for yaxis in fig.select_yaxes(col=1):
569        set_cartesian_axis_opts(args, yaxis, "y", orders)
570
571    # Set x-axis titles and axis options in the bottom-most row
572    for xaxis in fig.select_xaxes(row=1):
573        set_cartesian_axis_opts(args, xaxis, "x", orders)
574
575    # Configure axis ticks on marginal subplots
576    if args["marginal_x"]:
577        fig.update_yaxes(
578            showticklabels=False, showline=False, ticks="", range=None, row=nrows
579        )
580        if args["template"].layout.yaxis.showgrid is None:
581            fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
582        if args["template"].layout.xaxis.showgrid is None:
583            fig.update_xaxes(showgrid=True, row=nrows)
584
585    if args["marginal_y"]:
586        fig.update_xaxes(
587            showticklabels=False, showline=False, ticks="", range=None, col=ncols
588        )
589        if args["template"].layout.xaxis.showgrid is None:
590            fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
591        if args["template"].layout.yaxis.showgrid is None:
592            fig.update_yaxes(showgrid=True, col=ncols)
593
594    # Add axis titles to non-marginal subplots
595    y_title = get_decorated_label(args, args["y"], "y")
596    if args["marginal_x"]:
597        fig.update_yaxes(title_text=y_title, row=1, col=1)
598    else:
599        for row in range(1, nrows + 1):
600            fig.update_yaxes(title_text=y_title, row=row, col=1)
601
602    x_title = get_decorated_label(args, args["x"], "x")
603    if args["marginal_y"]:
604        fig.update_xaxes(title_text=x_title, row=1, col=1)
605    else:
606        for col in range(1, ncols + 1):
607            fig.update_xaxes(title_text=x_title, row=1, col=col)
608
609    # Configure axis type across all x-axes
610    if "log_x" in args and args["log_x"]:
611        fig.update_xaxes(type="log")
612
613    # Configure axis type across all y-axes
614    if "log_y" in args and args["log_y"]:
615        fig.update_yaxes(type="log")
616
617    # Configure matching and axis type for marginal y-axes
618    matches_y = "y" + str(ncols + 1)
619    if args["marginal_x"]:
620        for row in range(2, nrows + 1, 2):
621            fig.update_yaxes(matches=matches_y, type=None, row=row)
622
623    if args["marginal_y"]:
624        for col in range(2, ncols + 1, 2):
625            fig.update_xaxes(matches="x2", type=None, col=col)
626
627
628def configure_cartesian_axes(args, fig, orders):
629    if ("marginal_x" in args and args["marginal_x"]) or (
630        "marginal_y" in args and args["marginal_y"]
631    ):
632        configure_cartesian_marginal_axes(args, fig, orders)
633        return
634
635    # Set y-axis titles and axis options in the left-most column
636    y_title = get_decorated_label(args, args["y"], "y")
637    for yaxis in fig.select_yaxes(col=1):
638        yaxis.update(title_text=y_title)
639        set_cartesian_axis_opts(args, yaxis, "y", orders)
640
641    # Set x-axis titles and axis options in the bottom-most row
642    x_title = get_decorated_label(args, args["x"], "x")
643    for xaxis in fig.select_xaxes(row=1):
644        if "is_timeline" not in args:
645            xaxis.update(title_text=x_title)
646        set_cartesian_axis_opts(args, xaxis, "x", orders)
647
648    # Configure axis type across all x-axes
649    if "log_x" in args and args["log_x"]:
650        fig.update_xaxes(type="log")
651
652    # Configure axis type across all y-axes
653    if "log_y" in args and args["log_y"]:
654        fig.update_yaxes(type="log")
655
656    if "is_timeline" in args:
657        fig.update_xaxes(type="date")
658
659
660def configure_ternary_axes(args, fig, orders):
661    fig.update_ternaries(
662        aaxis=dict(title_text=get_label(args, args["a"])),
663        baxis=dict(title_text=get_label(args, args["b"])),
664        caxis=dict(title_text=get_label(args, args["c"])),
665    )
666
667
668def configure_polar_axes(args, fig, orders):
669    patch = dict(
670        angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
671        radialaxis=dict(),
672    )
673
674    for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]:
675        if args[var] in orders:
676            patch[axis]["categoryorder"] = "array"
677            patch[axis]["categoryarray"] = orders[args[var]]
678
679    radialaxis = patch["radialaxis"]
680    if args["log_r"]:
681        radialaxis["type"] = "log"
682        if args["range_r"]:
683            radialaxis["range"] = [math.log(x, 10) for x in args["range_r"]]
684    else:
685        if args["range_r"]:
686            radialaxis["range"] = args["range_r"]
687
688    if args["range_theta"]:
689        patch["sector"] = args["range_theta"]
690    fig.update_polars(patch)
691
692
693def configure_3d_axes(args, fig, orders):
694    patch = dict(
695        xaxis=dict(title_text=get_label(args, args["x"])),
696        yaxis=dict(title_text=get_label(args, args["y"])),
697        zaxis=dict(title_text=get_label(args, args["z"])),
698    )
699
700    for letter in ["x", "y", "z"]:
701        axis = patch[letter + "axis"]
702        if args["log_" + letter]:
703            axis["type"] = "log"
704            if args["range_" + letter]:
705                axis["range"] = [math.log(x, 10) for x in args["range_" + letter]]
706        else:
707            if args["range_" + letter]:
708                axis["range"] = args["range_" + letter]
709        if args[letter] in orders:
710            axis["categoryorder"] = "array"
711            axis["categoryarray"] = orders[args[letter]]
712    fig.update_scenes(patch)
713
714
715def configure_mapbox(args, fig, orders):
716    center = args["center"]
717    if not center and "lat" in args and "lon" in args:
718        center = dict(
719            lat=args["data_frame"][args["lat"]].mean(),
720            lon=args["data_frame"][args["lon"]].mean(),
721        )
722    fig.update_mapboxes(
723        accesstoken=MAPBOX_TOKEN,
724        center=center,
725        zoom=args["zoom"],
726        style=args["mapbox_style"],
727    )
728
729
730def configure_geo(args, fig, orders):
731    fig.update_geos(
732        center=args["center"],
733        scope=args["scope"],
734        fitbounds=args["fitbounds"],
735        visible=args["basemap_visible"],
736        projection=dict(type=args["projection"]),
737    )
738
739
740def configure_animation_controls(args, constructor, fig):
741    def frame_args(duration):
742        return {
743            "frame": {"duration": duration, "redraw": constructor != go.Scatter},
744            "mode": "immediate",
745            "fromcurrent": True,
746            "transition": {"duration": duration, "easing": "linear"},
747        }
748
749    if "animation_frame" in args and args["animation_frame"] and len(fig.frames) > 1:
750        fig.layout.updatemenus = [
751            {
752                "buttons": [
753                    {
754                        "args": [None, frame_args(500)],
755                        "label": "&#9654;",
756                        "method": "animate",
757                    },
758                    {
759                        "args": [[None], frame_args(0)],
760                        "label": "&#9724;",
761                        "method": "animate",
762                    },
763                ],
764                "direction": "left",
765                "pad": {"r": 10, "t": 70},
766                "showactive": False,
767                "type": "buttons",
768                "x": 0.1,
769                "xanchor": "right",
770                "y": 0,
771                "yanchor": "top",
772            }
773        ]
774        fig.layout.sliders = [
775            {
776                "active": 0,
777                "yanchor": "top",
778                "xanchor": "left",
779                "currentvalue": {
780                    "prefix": get_label(args, args["animation_frame"]) + "="
781                },
782                "pad": {"b": 10, "t": 60},
783                "len": 0.9,
784                "x": 0.1,
785                "y": 0,
786                "steps": [
787                    {
788                        "args": [[f.name], frame_args(0)],
789                        "label": f.name,
790                        "method": "animate",
791                    }
792                    for f in fig.frames
793                ],
794            }
795        ]
796
797
798def make_trace_spec(args, constructor, attrs, trace_patch):
799    if constructor in [go.Scatter, go.Scatterpolar]:
800        if "render_mode" in args and (
801            args["render_mode"] == "webgl"
802            or (
803                args["render_mode"] == "auto"
804                and len(args["data_frame"]) > 1000
805                and args["animation_frame"] is None
806            )
807        ):
808            if constructor == go.Scatter:
809                constructor = go.Scattergl
810                if "orientation" in trace_patch:
811                    del trace_patch["orientation"]
812            else:
813                constructor = go.Scatterpolargl
814    # Create base trace specification
815    result = [TraceSpec(constructor, attrs, trace_patch, None)]
816
817    # Add marginal trace specifications
818    for letter in ["x", "y"]:
819        if "marginal_" + letter in args and args["marginal_" + letter]:
820            trace_spec = None
821            axis_map = dict(
822                xaxis="x1" if letter == "x" else "x2",
823                yaxis="y1" if letter == "y" else "y2",
824            )
825            if args["marginal_" + letter] == "histogram":
826                trace_spec = TraceSpec(
827                    constructor=go.Histogram,
828                    attrs=[letter, "marginal_" + letter],
829                    trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map),
830                    marginal=letter,
831                )
832            elif args["marginal_" + letter] == "violin":
833                trace_spec = TraceSpec(
834                    constructor=go.Violin,
835                    attrs=[letter, "hover_name", "hover_data"],
836                    trace_patch=dict(scalegroup=letter),
837                    marginal=letter,
838                )
839            elif args["marginal_" + letter] == "box":
840                trace_spec = TraceSpec(
841                    constructor=go.Box,
842                    attrs=[letter, "hover_name", "hover_data"],
843                    trace_patch=dict(notched=True),
844                    marginal=letter,
845                )
846            elif args["marginal_" + letter] == "rug":
847                symbols = {"x": "line-ns-open", "y": "line-ew-open"}
848                trace_spec = TraceSpec(
849                    constructor=go.Box,
850                    attrs=[letter, "hover_name", "hover_data"],
851                    trace_patch=dict(
852                        fillcolor="rgba(255,255,255,0)",
853                        line={"color": "rgba(255,255,255,0)"},
854                        boxpoints="all",
855                        jitter=0,
856                        hoveron="points",
857                        marker={"symbol": symbols[letter]},
858                    ),
859                    marginal=letter,
860                )
861            if "color" in attrs or "color" not in args:
862                if "marker" not in trace_spec.trace_patch:
863                    trace_spec.trace_patch["marker"] = dict()
864                first_default_color = args["color_continuous_scale"][0]
865                trace_spec.trace_patch["marker"]["color"] = first_default_color
866            result.append(trace_spec)
867
868    # Add trendline trace specifications
869    if "trendline" in args and args["trendline"]:
870        trace_spec = TraceSpec(
871            constructor=go.Scattergl if constructor == go.Scattergl else go.Scatter,
872            attrs=["trendline"],
873            trace_patch=dict(mode="lines"),
874            marginal=None,
875        )
876        if args["trendline_color_override"]:
877            trace_spec.trace_patch["line"] = dict(
878                color=args["trendline_color_override"]
879            )
880        result.append(trace_spec)
881    return result
882
883
884def one_group(x):
885    return ""
886
887
888def apply_default_cascade(args):
889    # first we apply px.defaults to unspecified args
890
891    for param in defaults.__slots__:
892        if param in args and args[param] is None:
893            args[param] = getattr(defaults, param)
894
895    # load the default template if set, otherwise "plotly"
896    if args["template"] is None:
897        if pio.templates.default is not None:
898            args["template"] = pio.templates.default
899        else:
900            args["template"] = "plotly"
901
902    try:
903        # retrieve the actual template if we were given a name
904        args["template"] = pio.templates[args["template"]]
905    except Exception:
906        # otherwise try to build a real template
907        args["template"] = go.layout.Template(args["template"])
908
909    # if colors not set explicitly or in px.defaults, defer to a template
910    # if the template doesn't have one, we set some final fallback defaults
911    if "color_continuous_scale" in args:
912        if (
913            args["color_continuous_scale"] is None
914            and args["template"].layout.colorscale.sequential
915        ):
916            args["color_continuous_scale"] = [
917                x[1] for x in args["template"].layout.colorscale.sequential
918            ]
919        if args["color_continuous_scale"] is None:
920            args["color_continuous_scale"] = sequential.Viridis
921
922    if "color_discrete_sequence" in args:
923        if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
924            args["color_discrete_sequence"] = args["template"].layout.colorway
925        if args["color_discrete_sequence"] is None:
926            args["color_discrete_sequence"] = qualitative.D3
927
928    # if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
929    # see if we can defer to template. If not, set reasonable defaults
930    if "symbol_sequence" in args:
931        if args["symbol_sequence"] is None and args["template"].data.scatter:
932            args["symbol_sequence"] = [
933                scatter.marker.symbol for scatter in args["template"].data.scatter
934            ]
935        if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
936            args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]
937
938    if "line_dash_sequence" in args:
939        if args["line_dash_sequence"] is None and args["template"].data.scatter:
940            args["line_dash_sequence"] = [
941                scatter.line.dash for scatter in args["template"].data.scatter
942            ]
943        if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
944            args["line_dash_sequence"] = [
945                "solid",
946                "dot",
947                "dash",
948                "longdash",
949                "dashdot",
950                "longdashdot",
951            ]
952
953
954def _check_name_not_reserved(field_name, reserved_names):
955    if field_name not in reserved_names:
956        return field_name
957    else:
958        raise NameError(
959            "A name conflict was encountered for argument '%s'. "
960            "A column or index with name '%s' is ambiguous." % (field_name, field_name)
961        )
962
963
964def _get_reserved_col_names(args):
965    """
966    This function builds a list of columns of the data_frame argument used
967    as arguments, either as str/int arguments or given as columns
968    (pandas series type).
969    """
970    df = args["data_frame"]
971    reserved_names = set()
972    for field in args:
973        if field not in all_attrables:
974            continue
975        names = args[field] if field in array_attrables else [args[field]]
976        if names is None:
977            continue
978        for arg in names:
979            if arg is None:
980                continue
981            elif isinstance(arg, str):  # no need to add ints since kw arg are not ints
982                reserved_names.add(arg)
983            elif isinstance(arg, pd.Series):
984                arg_name = arg.name
985                if arg_name and hasattr(df, arg_name):
986                    in_df = arg is df[arg_name]
987                    if in_df:
988                        reserved_names.add(arg_name)
989            elif arg is df.index and arg.name is not None:
990                reserved_names.add(arg.name)
991
992    return reserved_names
993
994
995def _is_col_list(df_input, arg):
996    """Returns True if arg looks like it's a list of columns or references to columns
997    in df_input, and False otherwise (in which case it's assumed to be a single column
998    or reference to a column).
999    """
1000    if arg is None or isinstance(arg, str) or isinstance(arg, int):
1001        return False
1002    if isinstance(arg, pd.MultiIndex):
1003        return False  # just to keep existing behaviour for now
1004    try:
1005        iter(arg)
1006    except TypeError:
1007        return False  # not iterable
1008    for c in arg:
1009        if isinstance(c, str) or isinstance(c, int):
1010            if df_input is None or c not in df_input.columns:
1011                return False
1012        else:
1013            try:
1014                iter(c)
1015            except TypeError:
1016                return False  # not iterable
1017    return True
1018
1019
1020def _isinstance_listlike(x):
1021    """Returns True if x is an iterable which can be transformed into a pandas Series,
1022    False for the other types of possible values of a `hover_data` dict.
1023    A tuple of length 2 is a special case corresponding to a (format, data) tuple.
1024    """
1025    if (
1026        isinstance(x, str)
1027        or (isinstance(x, tuple) and len(x) == 2)
1028        or isinstance(x, bool)
1029        or x is None
1030    ):
1031        return False
1032    else:
1033        return True
1034
1035
1036def _escape_col_name(df_input, col_name, extra):
1037    while df_input is not None and (col_name in df_input.columns or col_name in extra):
1038        col_name = "_" + col_name
1039    return col_name
1040
1041
1042def to_unindexed_series(x):
1043    """
1044    assuming x is list-like or even an existing pd.Series, return a new pd.Series with
1045    no index, without extracting the data from an existing Series via numpy, which
1046    seems to mangle datetime columns. Stripping the index from existing pd.Series is
1047    required to get things to match up right in the new DataFrame we're building
1048    """
1049    return pd.Series(x).reset_index(drop=True)
1050
1051
1052def process_args_into_dataframe(args, wide_mode, var_name, value_name):
1053    """
1054    After this function runs, the `all_attrables` keys of `args` all contain only
1055    references to columns of `df_output`. This function handles the extraction of data
1056    from `args["attrable"]` and column-name-generation as appropriate, and adds the
1057    data to `df_output` and then replaces `args["attrable"]` with the appropriate
1058    reference.
1059    """
1060
1061    df_input = args["data_frame"]
1062    df_provided = df_input is not None
1063
1064    df_output = pd.DataFrame()
1065    constants = dict()
1066    ranges = list()
1067    wide_id_vars = set()
1068    reserved_names = _get_reserved_col_names(args) if df_provided else set()
1069
1070    # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords
1071    if "dimensions" in args and args["dimensions"] is None:
1072        if not df_provided:
1073            raise ValueError(
1074                "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
1075            )
1076        else:
1077            df_output[df_input.columns] = df_input[df_input.columns]
1078
1079    # hover_data is a dict
1080    hover_data_is_dict = (
1081        "hover_data" in args
1082        and args["hover_data"]
1083        and isinstance(args["hover_data"], dict)
1084    )
1085    # If dict, convert all values of hover_data to tuples to simplify processing
1086    if hover_data_is_dict:
1087        for k in args["hover_data"]:
1088            if _isinstance_listlike(args["hover_data"][k]):
1089                args["hover_data"][k] = (True, args["hover_data"][k])
1090            if not isinstance(args["hover_data"][k], tuple):
1091                args["hover_data"][k] = (args["hover_data"][k], None)
1092            if df_provided and args["hover_data"][k][1] is not None and k in df_input:
1093                raise ValueError(
1094                    "Ambiguous input: values for '%s' appear both in hover_data and data_frame"
1095                    % k
1096                )
1097    # Loop over possible arguments
1098    for field_name in all_attrables:
1099        # Massaging variables
1100        argument_list = (
1101            [args.get(field_name)]
1102            if field_name not in array_attrables
1103            else args.get(field_name)
1104        )
1105        # argument not specified, continue
1106        if argument_list is None or argument_list is [None]:
1107            continue
1108        # Argument name: field_name if the argument is not a list
1109        # Else we give names like ["hover_data_0, hover_data_1"] etc.
1110        field_list = (
1111            [field_name]
1112            if field_name not in array_attrables
1113            else [field_name + "_" + str(i) for i in range(len(argument_list))]
1114        )
1115        # argument_list and field_list ready, iterate over them
1116        # Core of the loop starts here
1117        for i, (argument, field) in enumerate(zip(argument_list, field_list)):
1118            length = len(df_output)
1119            if argument is None:
1120                continue
1121            col_name = None
1122            # Case of multiindex
1123            if isinstance(argument, pd.MultiIndex):
1124                raise TypeError(
1125                    "Argument '%s' is a pandas MultiIndex. "
1126                    "pandas MultiIndex is not supported by plotly express "
1127                    "at the moment." % field
1128                )
1129            # ----------------- argument is a special value ----------------------
1130            if isinstance(argument, Constant) or isinstance(argument, Range):
1131                col_name = _check_name_not_reserved(
1132                    str(argument.label) if argument.label is not None else field,
1133                    reserved_names,
1134                )
1135                if isinstance(argument, Constant):
1136                    constants[col_name] = argument.value
1137                else:
1138                    ranges.append(col_name)
1139            # ----------------- argument is likely a col name ----------------------
1140            elif isinstance(argument, str) or not hasattr(argument, "__len__"):
1141                if (
1142                    field_name == "hover_data"
1143                    and hover_data_is_dict
1144                    and args["hover_data"][str(argument)][1] is not None
1145                ):
1146                    # hover_data has onboard data
1147                    # previously-checked to have no name-conflict with data_frame
1148                    col_name = str(argument)
1149                    real_argument = args["hover_data"][col_name][1]
1150
1151                    if length and len(real_argument) != length:
1152                        raise ValueError(
1153                            "All arguments should have the same length. "
1154                            "The length of hover_data key `%s` is %d, whereas the "
1155                            "length of previously-processed arguments %s is %d"
1156                            % (
1157                                argument,
1158                                len(real_argument),
1159                                str(list(df_output.columns)),
1160                                length,
1161                            )
1162                        )
1163                    df_output[col_name] = to_unindexed_series(real_argument)
1164                elif not df_provided:
1165                    raise ValueError(
1166                        "String or int arguments are only possible when a "
1167                        "DataFrame or an array is provided in the `data_frame` "
1168                        "argument. No DataFrame was provided, but argument "
1169                        "'%s' is of type str or int." % field
1170                    )
1171                # Check validity of column name
1172                elif argument not in df_input.columns:
1173                    if wide_mode and argument in (value_name, var_name):
1174                        continue
1175                    else:
1176                        err_msg = (
1177                            "Value of '%s' is not the name of a column in 'data_frame'. "
1178                            "Expected one of %s but received: %s"
1179                            % (field, str(list(df_input.columns)), argument)
1180                        )
1181                        if argument == "index":
1182                            err_msg += "\n To use the index, pass it in directly as `df.index`."
1183                        raise ValueError(err_msg)
1184                elif length and len(df_input[argument]) != length:
1185                    raise ValueError(
1186                        "All arguments should have the same length. "
1187                        "The length of column argument `df[%s]` is %d, whereas the "
1188                        "length of  previously-processed arguments %s is %d"
1189                        % (
1190                            field,
1191                            len(df_input[argument]),
1192                            str(list(df_output.columns)),
1193                            length,
1194                        )
1195                    )
1196                else:
1197                    col_name = str(argument)
1198                    df_output[col_name] = to_unindexed_series(df_input[argument])
1199            # ----------------- argument is likely a column / array / list.... -------
1200            else:
1201                if df_provided and hasattr(argument, "name"):
1202                    if argument is df_input.index:
1203                        if argument.name is None or argument.name in df_input:
1204                            col_name = "index"
1205                        else:
1206                            col_name = argument.name
1207                        col_name = _escape_col_name(
1208                            df_input, col_name, [var_name, value_name]
1209                        )
1210                    else:
1211                        if (
1212                            argument.name is not None
1213                            and argument.name in df_input
1214                            and argument is df_input[argument.name]
1215                        ):
1216                            col_name = argument.name
1217                if col_name is None:  # numpy array, list...
1218                    col_name = _check_name_not_reserved(field, reserved_names)
1219
1220                if length and len(argument) != length:
1221                    raise ValueError(
1222                        "All arguments should have the same length. "
1223                        "The length of argument `%s` is %d, whereas the "
1224                        "length of  previously-processed arguments %s is %d"
1225                        % (field, len(argument), str(list(df_output.columns)), length)
1226                    )
1227                df_output[str(col_name)] = to_unindexed_series(argument)
1228
1229            # Finally, update argument with column name now that column exists
1230            assert col_name is not None, (
1231                "Data-frame processing failure, likely due to a internal bug. "
1232                "Please report this to "
1233                "https://github.com/plotly/plotly.py/issues/new and we will try to "
1234                "replicate and fix it."
1235            )
1236            if field_name not in array_attrables:
1237                args[field_name] = str(col_name)
1238            elif isinstance(args[field_name], dict):
1239                pass
1240            else:
1241                args[field_name][i] = str(col_name)
1242            if field_name != "wide_variable":
1243                wide_id_vars.add(str(col_name))
1244
1245    for col_name in ranges:
1246        df_output[col_name] = range(len(df_output))
1247
1248    for col_name in constants:
1249        df_output[col_name] = constants[col_name]
1250
1251    return df_output, wide_id_vars
1252
1253
1254def build_dataframe(args, constructor):
1255    """
1256    Constructs a dataframe and modifies `args` in-place.
1257
1258    The argument values in `args` can be either strings corresponding to
1259    existing columns of a dataframe, or data arrays (lists, numpy arrays,
1260    pandas columns, series).
1261
1262    Parameters
1263    ----------
1264    args : OrderedDict
1265        arguments passed to the px function and subsequently modified
1266    constructor : graph_object trace class
1267        the trace type selected for this figure
1268    """
1269
1270    # make copies of all the fields via dict() and list()
1271    for field in args:
1272        if field in array_attrables and args[field] is not None:
1273            args[field] = (
1274                dict(args[field])
1275                if isinstance(args[field], dict)
1276                else list(args[field])
1277            )
1278
1279    # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
1280    df_provided = args["data_frame"] is not None
1281    if df_provided and not isinstance(args["data_frame"], pd.DataFrame):
1282        args["data_frame"] = pd.DataFrame(args["data_frame"])
1283    df_input = args["data_frame"]
1284
1285    # now we handle special cases like wide-mode or x-xor-y specification
1286    # by rearranging args to tee things up for process_args_into_dataframe to work
1287    no_x = args.get("x", None) is None
1288    no_y = args.get("y", None) is None
1289    wide_x = False if no_x else _is_col_list(df_input, args["x"])
1290    wide_y = False if no_y else _is_col_list(df_input, args["y"])
1291
1292    wide_mode = False
1293    var_name = None  # will likely be "variable" in wide_mode
1294    wide_cross_name = None  # will likely be "index" in wide_mode
1295    value_name = None  # will likely be "value" in wide_mode
1296    hist2d_types = [go.Histogram2d, go.Histogram2dContour]
1297    if constructor in cartesians:
1298        if wide_x and wide_y:
1299            raise ValueError(
1300                "Cannot accept list of column references or list of columns for both `x` and `y`."
1301            )
1302        if df_provided and no_x and no_y:
1303            wide_mode = True
1304            if isinstance(df_input.columns, pd.MultiIndex):
1305                raise TypeError(
1306                    "Data frame columns is a pandas MultiIndex. "
1307                    "pandas MultiIndex is not supported by plotly express "
1308                    "at the moment."
1309                )
1310            args["wide_variable"] = list(df_input.columns)
1311            var_name = df_input.columns.name
1312            if var_name in [None, "value", "index"] or var_name in df_input:
1313                var_name = "variable"
1314            if constructor == go.Funnel:
1315                wide_orientation = args.get("orientation", None) or "h"
1316            else:
1317                wide_orientation = args.get("orientation", None) or "v"
1318            args["orientation"] = wide_orientation
1319            args["wide_cross"] = None
1320        elif wide_x != wide_y:
1321            wide_mode = True
1322            args["wide_variable"] = args["y"] if wide_y else args["x"]
1323            if df_provided and args["wide_variable"] is df_input.columns:
1324                var_name = df_input.columns.name
1325            if isinstance(args["wide_variable"], pd.Index):
1326                args["wide_variable"] = list(args["wide_variable"])
1327            if var_name in [None, "value", "index"] or (
1328                df_provided and var_name in df_input
1329            ):
1330                var_name = "variable"
1331            if constructor == go.Histogram:
1332                wide_orientation = "v" if wide_x else "h"
1333            else:
1334                wide_orientation = "v" if wide_y else "h"
1335            args["y" if wide_y else "x"] = None
1336            args["wide_cross"] = None
1337            if not no_x and not no_y:
1338                wide_cross_name = "__x__" if wide_y else "__y__"
1339
1340    if wide_mode:
1341        value_name = _escape_col_name(df_input, "value", [])
1342        var_name = _escape_col_name(df_input, var_name, [])
1343
1344    missing_bar_dim = None
1345    if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types:
1346        if not wide_mode and (no_x != no_y):
1347            for ax in ["x", "y"]:
1348                if args.get(ax, None) is None:
1349                    args[ax] = df_input.index if df_provided else Range()
1350                    if constructor == go.Bar:
1351                        missing_bar_dim = ax
1352                    else:
1353                        if args["orientation"] is None:
1354                            args["orientation"] = "v" if ax == "x" else "h"
1355        if wide_mode and wide_cross_name is None:
1356            if no_x != no_y and args["orientation"] is None:
1357                args["orientation"] = "v" if no_x else "h"
1358            if df_provided:
1359                if isinstance(df_input.index, pd.MultiIndex):
1360                    raise TypeError(
1361                        "Data frame index is a pandas MultiIndex. "
1362                        "pandas MultiIndex is not supported by plotly express "
1363                        "at the moment."
1364                    )
1365                args["wide_cross"] = df_input.index
1366            else:
1367                args["wide_cross"] = Range(
1368                    label=_escape_col_name(df_input, "index", [var_name, value_name])
1369                )
1370
1371    no_color = False
1372    if type(args.get("color", None)) == str and args["color"] == NO_COLOR:
1373        no_color = True
1374        args["color"] = None
1375    # now that things have been prepped, we do the systematic rewriting of `args`
1376
1377    df_output, wide_id_vars = process_args_into_dataframe(
1378        args, wide_mode, var_name, value_name
1379    )
1380
1381    # now that `df_output` exists and `args` contains only references, we complete
1382    # the special-case and wide-mode handling by further rewriting args and/or mutating
1383    # df_output
1384
1385    count_name = _escape_col_name(df_output, "count", [var_name, value_name])
1386    if not wide_mode and missing_bar_dim and constructor == go.Bar:
1387        # now that we've populated df_output, we check to see if the non-missing
1388        # dimension is categorical: if so, then setting the missing dimension to a
1389        # constant 1 is a less-insane thing to do than setting it to the index by
1390        # default and we let the normal auto-orientation-code do its thing later
1391        other_dim = "x" if missing_bar_dim == "y" else "y"
1392        if not _is_continuous(df_output, args[other_dim]):
1393            args[missing_bar_dim] = count_name
1394            df_output[count_name] = 1
1395        else:
1396            # on the other hand, if the non-missing dimension is continuous, then we
1397            # can use this information to override the normal auto-orientation code
1398            if args["orientation"] is None:
1399                args["orientation"] = "v" if missing_bar_dim == "x" else "h"
1400
1401    if constructor in hist2d_types:
1402        del args["orientation"]
1403
1404    if wide_mode:
1405        # at this point, `df_output` is semi-long/semi-wide, but we know which columns
1406        # are which, so we melt it and reassign `args` to refer to the newly-tidy
1407        # columns, keeping track of various names and manglings set up above
1408        wide_value_vars = [c for c in args["wide_variable"] if c not in wide_id_vars]
1409        del args["wide_variable"]
1410        if wide_cross_name == "__x__":
1411            wide_cross_name = args["x"]
1412        elif wide_cross_name == "__y__":
1413            wide_cross_name = args["y"]
1414        else:
1415            wide_cross_name = args["wide_cross"]
1416        del args["wide_cross"]
1417        dtype = None
1418        for v in wide_value_vars:
1419            v_dtype = df_output[v].dtype.kind
1420            v_dtype = "number" if v_dtype in ["i", "f", "u"] else v_dtype
1421            if dtype is None:
1422                dtype = v_dtype
1423            elif dtype != v_dtype:
1424                raise ValueError(
1425                    "Plotly Express cannot process wide-form data with columns of different type."
1426                )
1427        df_output = df_output.melt(
1428            id_vars=wide_id_vars,
1429            value_vars=wide_value_vars,
1430            var_name=var_name,
1431            value_name=value_name,
1432        )
1433        assert len(df_output.columns) == len(set(df_output.columns)), (
1434            "Wide-mode name-inference failure, likely due to a internal bug. "
1435            "Please report this to "
1436            "https://github.com/plotly/plotly.py/issues/new and we will try to "
1437            "replicate and fix it."
1438        )
1439        df_output[var_name] = df_output[var_name].astype(str)
1440        orient_v = wide_orientation == "v"
1441
1442        if constructor in [go.Scatter, go.Funnel] + hist2d_types:
1443            args["x" if orient_v else "y"] = wide_cross_name
1444            args["y" if orient_v else "x"] = value_name
1445            if constructor != go.Histogram2d:
1446                args["color"] = args["color"] or var_name
1447            if "line_group" in args:
1448                args["line_group"] = args["line_group"] or var_name
1449        if constructor == go.Bar:
1450            if _is_continuous(df_output, value_name):
1451                args["x" if orient_v else "y"] = wide_cross_name
1452                args["y" if orient_v else "x"] = value_name
1453                args["color"] = args["color"] or var_name
1454            else:
1455                args["x" if orient_v else "y"] = value_name
1456                args["y" if orient_v else "x"] = count_name
1457                df_output[count_name] = 1
1458                args["color"] = args["color"] or var_name
1459        if constructor in [go.Violin, go.Box]:
1460            args["x" if orient_v else "y"] = wide_cross_name or var_name
1461            args["y" if orient_v else "x"] = value_name
1462        if constructor == go.Histogram:
1463            args["x" if orient_v else "y"] = value_name
1464            args["y" if orient_v else "x"] = wide_cross_name
1465            args["color"] = args["color"] or var_name
1466    if no_color:
1467        args["color"] = None
1468    args["data_frame"] = df_output
1469    return args
1470
1471
1472def _check_dataframe_all_leaves(df):
1473    df_sorted = df.sort_values(by=list(df.columns))
1474    null_mask = df_sorted.isnull()
1475    df_sorted = df_sorted.astype(str)
1476    null_indices = np.nonzero(null_mask.any(axis=1).values)[0]
1477    for null_row_index in null_indices:
1478        row = null_mask.iloc[null_row_index]
1479        i = np.nonzero(row.values)[0][0]
1480        if not row[i:].all():
1481            raise ValueError(
1482                "None entries cannot have not-None children",
1483                df_sorted.iloc[null_row_index],
1484            )
1485    df_sorted[null_mask] = ""
1486    row_strings = list(df_sorted.apply(lambda x: "".join(x), axis=1))
1487    for i, row in enumerate(row_strings[:-1]):
1488        if row_strings[i + 1] in row and (i + 1) in null_indices:
1489            raise ValueError(
1490                "Non-leaves rows are not permitted in the dataframe \n",
1491                df_sorted.iloc[i + 1],
1492                "is not a leaf.",
1493            )
1494
1495
1496def process_dataframe_hierarchy(args):
1497    """
1498    Build dataframe for sunburst or treemap when the path argument is provided.
1499    """
1500    df = args["data_frame"]
1501    path = args["path"][::-1]
1502    _check_dataframe_all_leaves(df[path[::-1]])
1503    discrete_color = False
1504
1505    new_path = []
1506    for col_name in path:
1507        new_col_name = col_name + "_path_copy"
1508        new_path.append(new_col_name)
1509        df[new_col_name] = df[col_name]
1510    path = new_path
1511    # ------------ Define aggregation functions --------------------------------
1512
1513    def aggfunc_discrete(x):
1514        uniques = x.unique()
1515        if len(uniques) == 1:
1516            return uniques[0]
1517        else:
1518            return "(?)"
1519
1520    agg_f = {}
1521    aggfunc_color = None
1522    if args["values"]:
1523        try:
1524            df[args["values"]] = pd.to_numeric(df[args["values"]])
1525        except ValueError:
1526            raise ValueError(
1527                "Column `%s` of `df` could not be converted to a numerical data type."
1528                % args["values"]
1529            )
1530
1531        if args["color"]:
1532            if args["color"] == args["values"]:
1533                new_value_col_name = args["values"] + "_sum"
1534                df[new_value_col_name] = df[args["values"]]
1535                args["values"] = new_value_col_name
1536        count_colname = args["values"]
1537    else:
1538        # we need a count column for the first groupby and the weighted mean of color
1539        # trick to be sure the col name is unused: take the sum of existing names
1540        count_colname = (
1541            "count"
1542            if "count" not in df.columns
1543            else "".join([str(el) for el in list(df.columns)])
1544        )
1545        # we can modify df because it's a copy of the px argument
1546        df[count_colname] = 1
1547        args["values"] = count_colname
1548    agg_f[count_colname] = "sum"
1549
1550    if args["color"]:
1551        if not _is_continuous(df, args["color"]):
1552            aggfunc_color = aggfunc_discrete
1553            discrete_color = True
1554        else:
1555
1556            def aggfunc_continuous(x):
1557                return np.average(x, weights=df.loc[x.index, count_colname])
1558
1559            aggfunc_color = aggfunc_continuous
1560        agg_f[args["color"]] = aggfunc_color
1561
1562    #  Other columns (for color, hover_data, custom_data etc.)
1563    cols = list(set(df.columns).difference(path))
1564    for col in cols:  # for hover_data, custom_data etc.
1565        if col not in agg_f:
1566            agg_f[col] = aggfunc_discrete
1567    # Avoid collisions with reserved names - columns in the path have been copied already
1568    cols = list(set(cols) - set(["labels", "parent", "id"]))
1569    # ----------------------------------------------------------------------------
1570    df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols)
1571    #  Set column type here (useful for continuous vs discrete colorscale)
1572    for col in cols:
1573        df_all_trees[col] = df_all_trees[col].astype(df[col].dtype)
1574    for i, level in enumerate(path):
1575        df_tree = pd.DataFrame(columns=df_all_trees.columns)
1576        dfg = df.groupby(path[i:]).agg(agg_f)
1577        dfg = dfg.reset_index()
1578        # Path label massaging
1579        df_tree["labels"] = dfg[level].copy().astype(str)
1580        df_tree["parent"] = ""
1581        df_tree["id"] = dfg[level].copy().astype(str)
1582        if i < len(path) - 1:
1583            j = i + 1
1584            while j < len(path):
1585                df_tree["parent"] = (
1586                    dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"]
1587                )
1588                df_tree["id"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["id"]
1589                j += 1
1590
1591        df_tree["parent"] = df_tree["parent"].str.rstrip("/")
1592        if cols:
1593            df_tree[cols] = dfg[cols]
1594        df_all_trees = df_all_trees.append(df_tree, ignore_index=True)
1595
1596    # we want to make sure than (?) is the first color of the sequence
1597    if args["color"] and discrete_color:
1598        sort_col_name = "sort_color_if_discrete_color"
1599        while sort_col_name in df_all_trees.columns:
1600            sort_col_name += "0"
1601        df_all_trees[sort_col_name] = df[args["color"]].astype(str)
1602        df_all_trees = df_all_trees.sort_values(by=sort_col_name)
1603
1604    # Now modify arguments
1605    args["data_frame"] = df_all_trees
1606    args["path"] = None
1607    args["ids"] = "id"
1608    args["names"] = "labels"
1609    args["parents"] = "parent"
1610    if args["color"]:
1611        if not args["hover_data"]:
1612            args["hover_data"] = [args["color"]]
1613        elif isinstance(args["hover_data"], dict):
1614            if not args["hover_data"].get(args["color"]):
1615                args["hover_data"][args["color"]] = (True, None)
1616        else:
1617            args["hover_data"].append(args["color"])
1618    return args
1619
1620
1621def process_dataframe_timeline(args):
1622    """
1623    Massage input for bar traces for px.timeline()
1624    """
1625    args["is_timeline"] = True
1626    if args["x_start"] is None or args["x_end"] is None:
1627        raise ValueError("Both x_start and x_end are required")
1628
1629    try:
1630        x_start = pd.to_datetime(args["data_frame"][args["x_start"]])
1631        x_end = pd.to_datetime(args["data_frame"][args["x_end"]])
1632    except (ValueError, TypeError):
1633        raise TypeError(
1634            "Both x_start and x_end must refer to data convertible to datetimes."
1635        )
1636
1637    # note that we are not adding any columns to the data frame here, so no risk of overwrite
1638    args["data_frame"][args["x_end"]] = (x_end - x_start).astype("timedelta64[ms]")
1639    args["x"] = args["x_end"]
1640    del args["x_end"]
1641    args["base"] = args["x_start"]
1642    del args["x_start"]
1643    return args
1644
1645
1646def infer_config(args, constructor, trace_patch, layout_patch):
1647    attrs = [k for k in direct_attrables + array_attrables if k in args]
1648    grouped_attrs = []
1649
1650    # Compute sizeref
1651    sizeref = 0
1652    if "size" in args and args["size"]:
1653        sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2
1654
1655    # Compute color attributes and grouping attributes
1656    if "color" in args:
1657        if "color_continuous_scale" in args:
1658            if "color_discrete_sequence" not in args:
1659                attrs.append("color")
1660            else:
1661                if args["color"] and _is_continuous(args["data_frame"], args["color"]):
1662                    attrs.append("color")
1663                    args["color_is_continuous"] = True
1664                elif constructor in [go.Sunburst, go.Treemap]:
1665                    attrs.append("color")
1666                    args["color_is_continuous"] = False
1667                else:
1668                    grouped_attrs.append("marker.color")
1669        elif "line_group" in args or constructor == go.Histogram2dContour:
1670            grouped_attrs.append("line.color")
1671        elif constructor in [go.Pie, go.Funnelarea]:
1672            attrs.append("color")
1673            if args["color"]:
1674                if args["hover_data"] is None:
1675                    args["hover_data"] = []
1676                args["hover_data"].append(args["color"])
1677        else:
1678            grouped_attrs.append("marker.color")
1679
1680        show_colorbar = bool(
1681            "color" in attrs
1682            and args["color"]
1683            and constructor not in [go.Pie, go.Funnelarea]
1684            and (
1685                constructor not in [go.Treemap, go.Sunburst]
1686                or args.get("color_is_continuous")
1687            )
1688        )
1689    else:
1690        show_colorbar = False
1691
1692    # Compute line_dash grouping attribute
1693    if "line_dash" in args:
1694        grouped_attrs.append("line.dash")
1695
1696    # Compute symbol grouping attribute
1697    if "symbol" in args:
1698        grouped_attrs.append("marker.symbol")
1699
1700    if "orientation" in args:
1701        has_x = args["x"] is not None
1702        has_y = args["y"] is not None
1703        if args["orientation"] is None:
1704            if constructor in [go.Histogram, go.Scatter]:
1705                if has_y and not has_x:
1706                    args["orientation"] = "h"
1707            elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]:
1708                if has_x and not has_y:
1709                    args["orientation"] = "h"
1710
1711        if args["orientation"] is None and has_x and has_y:
1712            x_is_continuous = _is_continuous(args["data_frame"], args["x"])
1713            y_is_continuous = _is_continuous(args["data_frame"], args["y"])
1714            if x_is_continuous and not y_is_continuous:
1715                args["orientation"] = "h"
1716            if y_is_continuous and not x_is_continuous:
1717                args["orientation"] = "v"
1718
1719        if args["orientation"] is None:
1720            args["orientation"] = "v"
1721
1722        if constructor == go.Histogram:
1723            if has_x and has_y and args["histfunc"] is None:
1724                args["histfunc"] = trace_patch["histfunc"] = "sum"
1725
1726            orientation = args["orientation"]
1727            nbins = args["nbins"]
1728            trace_patch["nbinsx"] = nbins if orientation == "v" else None
1729            trace_patch["nbinsy"] = None if orientation == "v" else nbins
1730            trace_patch["bingroup"] = "x" if orientation == "v" else "y"
1731        trace_patch["orientation"] = args["orientation"]
1732
1733        if constructor in [go.Violin, go.Box]:
1734            mode = "boxmode" if constructor == go.Box else "violinmode"
1735            if layout_patch[mode] is None and args["color"] is not None:
1736                if args["y"] == args["color"] and args["orientation"] == "h":
1737                    layout_patch[mode] = "overlay"
1738                elif args["x"] == args["color"] and args["orientation"] == "v":
1739                    layout_patch[mode] = "overlay"
1740            if layout_patch[mode] is None:
1741                layout_patch[mode] = "group"
1742
1743    if (
1744        constructor == go.Histogram2d
1745        and args["z"] is not None
1746        and args["histfunc"] is None
1747    ):
1748        args["histfunc"] = trace_patch["histfunc"] = "sum"
1749
1750    if constructor in [go.Histogram2d, go.Densitymapbox]:
1751        show_colorbar = True
1752        trace_patch["coloraxis"] = "coloraxis1"
1753
1754    if "opacity" in args:
1755        if args["opacity"] is None:
1756            if "barmode" in args and args["barmode"] == "overlay":
1757                trace_patch["marker"] = dict(opacity=0.5)
1758        elif constructor in [go.Densitymapbox, go.Pie, go.Funnel, go.Funnelarea]:
1759            trace_patch["opacity"] = args["opacity"]
1760        else:
1761            trace_patch["marker"] = dict(opacity=args["opacity"])
1762    if "line_group" in args:
1763        trace_patch["mode"] = "lines" + ("+markers+text" if args["text"] else "")
1764    elif constructor != go.Splom and (
1765        "symbol" in args or constructor == go.Scattermapbox
1766    ):
1767        trace_patch["mode"] = "markers" + ("+text" if args["text"] else "")
1768
1769    if "line_shape" in args:
1770        trace_patch["line"] = dict(shape=args["line_shape"])
1771
1772    if "geojson" in args:
1773        trace_patch["featureidkey"] = args["featureidkey"]
1774        trace_patch["geojson"] = (
1775            args["geojson"]
1776            if not hasattr(args["geojson"], "__geo_interface__")  # for geopandas
1777            else args["geojson"].__geo_interface__
1778        )
1779
1780    # Compute marginal attribute
1781    if "marginal" in args:
1782        position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
1783        other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y"
1784        args[position] = args["marginal"]
1785        args[other_position] = None
1786
1787    # If both marginals and faceting are specified, faceting wins
1788    if args.get("facet_col", None) is not None and args.get("marginal_y", None):
1789        args["marginal_y"] = None
1790
1791    if args.get("facet_row", None) is not None and args.get("marginal_x", None):
1792        args["marginal_x"] = None
1793
1794    # facet_col_wrap only works if no marginals or row faceting is used
1795    if (
1796        args.get("marginal_x", None) is not None
1797        or args.get("marginal_y", None) is not None
1798        or args.get("facet_row", None) is not None
1799    ):
1800        args["facet_col_wrap"] = 0
1801
1802    # Compute applicable grouping attributes
1803    for k in group_attrables:
1804        if k in args:
1805            grouped_attrs.append(k)
1806
1807    # Create grouped mappings
1808    grouped_mappings = [make_mapping(args, a) for a in grouped_attrs]
1809
1810    # Create trace specs
1811    trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
1812    return trace_specs, grouped_mappings, sizeref, show_colorbar
1813
1814
1815def get_orderings(args, grouper, grouped):
1816    """
1817    `orders` is the user-supplied ordering (with the remaining data-frame-supplied
1818    ordering appended if the column is used for grouping). It includes anything the user
1819    gave, for any variable, including values not present in the dataset. It is used
1820    downstream to set e.g. `categoryarray` for cartesian axes
1821
1822    `group_names` is the set of groups, ordered by the order above
1823
1824    `group_values` is a subset of `orders` in both keys and values. It contains a key
1825     for every grouped mapping and its values are the sorted *data* values for these
1826     mappings.
1827    """
1828    orders = {} if "category_orders" not in args else args["category_orders"].copy()
1829    group_names = []
1830    group_values = {}
1831    for group_name in grouped.groups:
1832        if len(grouper) == 1:
1833            group_name = (group_name,)
1834        group_names.append(group_name)
1835        for col in grouper:
1836            if col != one_group:
1837                uniques = args["data_frame"][col].unique()
1838                if col not in orders:
1839                    orders[col] = list(uniques)
1840                else:
1841                    for val in uniques:
1842                        if val not in orders[col]:
1843                            orders[col].append(val)
1844                group_values[col] = sorted(uniques, key=orders[col].index)
1845
1846    for i, col in reversed(list(enumerate(grouper))):
1847        if col != one_group:
1848            group_names = sorted(
1849                group_names,
1850                key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
1851            )
1852
1853    return orders, group_names, group_values
1854
1855
1856def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1857    trace_patch = trace_patch or {}
1858    layout_patch = layout_patch or {}
1859    apply_default_cascade(args)
1860
1861    args = build_dataframe(args, constructor)
1862    if constructor in [go.Treemap, go.Sunburst] and args["path"] is not None:
1863        args = process_dataframe_hierarchy(args)
1864    if constructor == "timeline":
1865        constructor = go.Bar
1866        args = process_dataframe_timeline(args)
1867
1868    trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
1869        args, constructor, trace_patch, layout_patch
1870    )
1871    grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
1872    grouped = args["data_frame"].groupby(grouper, sort=False)
1873
1874    orders, sorted_group_names, sorted_group_values = get_orderings(
1875        args, grouper, grouped
1876    )
1877
1878    col_labels = []
1879    row_labels = []
1880
1881    for m in grouped_mappings:
1882        if m.grouper:
1883            if m.facet == "col":
1884                prefix = get_label(args, args["facet_col"]) + "="
1885                col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
1886            if m.facet == "row":
1887                prefix = get_label(args, args["facet_row"]) + "="
1888                row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
1889            for val in sorted_group_values[m.grouper]:
1890                if val not in m.val_map:
1891                    m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
1892
1893    subplot_type = _subplot_type_for_trace_type(constructor().type)
1894
1895    trace_names_by_frame = {}
1896    frames = OrderedDict()
1897    trendline_rows = []
1898    nrows = ncols = 1
1899    trace_name_labels = None
1900    for group_name in sorted_group_names:
1901        group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
1902        mapping_labels = OrderedDict()
1903        trace_name_labels = OrderedDict()
1904        frame_name = ""
1905        for col, val, m in zip(grouper, group_name, grouped_mappings):
1906            if col != one_group:
1907                key = get_label(args, col)
1908                if not isinstance(m.val_map, IdentityMap):
1909                    mapping_labels[key] = str(val)
1910                    if m.show_in_trace_name:
1911                        trace_name_labels[key] = str(val)
1912                if m.variable == "animation_frame":
1913                    frame_name = val
1914        trace_name = ", ".join(trace_name_labels.values())
1915        if frame_name not in trace_names_by_frame:
1916            trace_names_by_frame[frame_name] = set()
1917        trace_names = trace_names_by_frame[frame_name]
1918
1919        for trace_spec in trace_specs:
1920            # Create the trace
1921            trace = trace_spec.constructor(name=trace_name)
1922            if trace_spec.constructor not in [
1923                go.Parcats,
1924                go.Parcoords,
1925                go.Choropleth,
1926                go.Choroplethmapbox,
1927                go.Densitymapbox,
1928                go.Histogram2d,
1929                go.Sunburst,
1930                go.Treemap,
1931            ]:
1932                trace.update(
1933                    legendgroup=trace_name,
1934                    showlegend=(trace_name != "" and trace_name not in trace_names),
1935                )
1936            if trace_spec.constructor in [go.Bar, go.Violin, go.Box, go.Histogram]:
1937                trace.update(alignmentgroup=True, offsetgroup=trace_name)
1938            trace_names.add(trace_name)
1939
1940            # Init subplot row/col
1941            trace._subplot_row = 1
1942            trace._subplot_col = 1
1943
1944            for i, m in enumerate(grouped_mappings):
1945                val = group_name[i]
1946                if val not in m.val_map:
1947                    m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
1948                try:
1949                    m.updater(trace, m.val_map[val])  # covers most cases
1950                except ValueError:
1951                    # this catches some odd cases like marginals
1952                    if (
1953                        trace_spec != trace_specs[0]
1954                        and trace_spec.constructor in [go.Violin, go.Box, go.Histogram]
1955                        and m.variable == "symbol"
1956                    ):
1957                        pass
1958                    elif (
1959                        trace_spec != trace_specs[0]
1960                        and trace_spec.constructor in [go.Histogram]
1961                        and m.variable == "color"
1962                    ):
1963                        trace.update(marker=dict(color=m.val_map[val]))
1964                    elif (
1965                        trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox]
1966                        and m.variable == "color"
1967                    ):
1968                        trace.update(
1969                            z=[1] * len(group),
1970                            colorscale=[m.val_map[val]] * 2,
1971                            showscale=False,
1972                            showlegend=True,
1973                        )
1974                    else:
1975                        raise
1976
1977                # Find row for trace, handling facet_row and marginal_x
1978                if m.facet == "row":
1979                    row = m.val_map[val]
1980                else:
1981                    if (
1982                        bool(args.get("marginal_x", False))
1983                        and trace_spec.marginal != "x"
1984                    ):
1985                        row = 2
1986                    else:
1987                        row = 1
1988
1989                facet_col_wrap = args.get("facet_col_wrap", 0)
1990                # Find col for trace, handling facet_col and marginal_y
1991                if m.facet == "col":
1992                    col = m.val_map[val]
1993                    if facet_col_wrap:  # assumes no facet_row, no marginals
1994                        row = 1 + ((col - 1) // facet_col_wrap)
1995                        col = 1 + ((col - 1) % facet_col_wrap)
1996                else:
1997                    if trace_spec.marginal == "y":
1998                        col = 2
1999                    else:
2000                        col = 1
2001
2002                nrows = max(nrows, row)
2003                if row > 1:
2004                    trace._subplot_row = row
2005
2006                ncols = max(ncols, col)
2007                if col > 1:
2008                    trace._subplot_col = col
2009            if (
2010                trace_specs[0].constructor == go.Histogram2dContour
2011                and trace_spec.constructor == go.Box
2012                and trace.line.color
2013            ):
2014                trace.update(marker=dict(color=trace.line.color))
2015
2016            patch, fit_results = make_trace_kwargs(
2017                args, trace_spec, group, mapping_labels.copy(), sizeref
2018            )
2019            trace.update(patch)
2020            if fit_results is not None:
2021                trendline_rows.append(mapping_labels.copy())
2022                trendline_rows[-1]["px_fit_results"] = fit_results
2023            if frame_name not in frames:
2024                frames[frame_name] = dict(data=[], name=frame_name)
2025            frames[frame_name]["data"].append(trace)
2026    frame_list = [f for f in frames.values()]
2027    if len(frame_list) > 1:
2028        frame_list = sorted(
2029            frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"])
2030        )
2031
2032    if show_colorbar:
2033        colorvar = "z" if constructor in [go.Histogram2d, go.Densitymapbox] else "color"
2034        range_color = args["range_color"] or [None, None]
2035
2036        colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
2037        layout_patch["coloraxis1"] = dict(
2038            colorscale=colorscale_validator.validate_coerce(
2039                args["color_continuous_scale"]
2040            ),
2041            cmid=args["color_continuous_midpoint"],
2042            cmin=range_color[0],
2043            cmax=range_color[1],
2044            colorbar=dict(
2045                title_text=get_decorated_label(args, args[colorvar], colorvar)
2046            ),
2047        )
2048    for v in ["height", "width"]:
2049        if args[v]:
2050            layout_patch[v] = args[v]
2051    layout_patch["legend"] = dict(tracegroupgap=0)
2052    if trace_name_labels:
2053        layout_patch["legend"]["title_text"] = ", ".join(trace_name_labels)
2054    if args["title"]:
2055        layout_patch["title_text"] = args["title"]
2056    elif args["template"].layout.margin.t is None:
2057        layout_patch["margin"] = {"t": 60}
2058    if (
2059        "size" in args
2060        and args["size"]
2061        and args["template"].layout.legend.itemsizing is None
2062    ):
2063        layout_patch["legend"]["itemsizing"] = "constant"
2064
2065    fig = init_figure(
2066        args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
2067    )
2068
2069    # Position traces in subplots
2070    for frame in frame_list:
2071        for trace in frame["data"]:
2072            if isinstance(trace, go.Splom):
2073                # Special case that is not compatible with make_subplots
2074                continue
2075
2076            _set_trace_grid_reference(
2077                trace,
2078                fig.layout,
2079                fig._grid_ref,
2080                nrows - trace._subplot_row + 1,
2081                trace._subplot_col,
2082            )
2083
2084    # Add traces, layout and frames to figure
2085    fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
2086    fig.update_layout(layout_patch)
2087    if "template" in args and args["template"] is not None:
2088        fig.update_layout(template=args["template"], overwrite=True)
2089    fig.frames = frame_list if len(frames) > 1 else []
2090
2091    fig._px_trendlines = pd.DataFrame(trendline_rows)
2092
2093    configure_axes(args, constructor, fig, orders)
2094    configure_animation_controls(args, constructor, fig)
2095    return fig
2096
2097
2098def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
2099    # Build subplot specs
2100    specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)]
2101
2102    # Default row/column widths uniform
2103    column_widths = [1.0] * ncols
2104    row_heights = [1.0] * nrows
2105    facet_col_wrap = args.get("facet_col_wrap", 0)
2106
2107    # Build column_widths/row_heights
2108    if subplot_type == "xy":
2109        if bool(args.get("marginal_x", False)):
2110            if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
2111                main_size = 0.74
2112            else:
2113                main_size = 0.84
2114
2115            row_heights = [main_size] * (nrows - 1) + [1 - main_size]
2116            vertical_spacing = 0.01
2117        elif facet_col_wrap:
2118            vertical_spacing = args.get("facet_row_spacing", None) or 0.07
2119        else:
2120            vertical_spacing = args.get("facet_row_spacing", None) or 0.03
2121
2122        if bool(args.get("marginal_y", False)):
2123            if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
2124                main_size = 0.74
2125            else:
2126                main_size = 0.84
2127
2128            column_widths = [main_size] * (ncols - 1) + [1 - main_size]
2129            horizontal_spacing = 0.005
2130        else:
2131            horizontal_spacing = args.get("facet_col_spacing", None) or 0.02
2132    else:
2133        # Other subplot types:
2134        #   'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None
2135        #
2136        # We can customize subplot spacing per type once we enable faceting
2137        # for all plot types
2138        if facet_col_wrap:
2139            vertical_spacing = args.get("facet_row_spacing", None) or 0.07
2140        else:
2141            vertical_spacing = args.get("facet_row_spacing", None) or 0.03
2142        horizontal_spacing = args.get("facet_col_spacing", None) or 0.02
2143
2144    if facet_col_wrap:
2145        subplot_labels = [None] * nrows * ncols
2146        while len(col_labels) < nrows * ncols:
2147            col_labels.append(None)
2148        for i in range(nrows):
2149            for j in range(ncols):
2150                subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j]
2151
2152    def _spacing_error_translator(e, direction, facet_arg):
2153        """
2154        Translates the spacing errors thrown by the underlying make_subplots
2155        routine into one that describes an argument adjustable through px.
2156        """
2157        if ("%s spacing" % (direction,)) in e.args[0]:
2158            e.args = (
2159                e.args[0]
2160                + """
2161Use the {facet_arg} argument to adjust this spacing.""".format(
2162                    facet_arg=facet_arg
2163                ),
2164            )
2165            raise e
2166
2167    # Create figure with subplots
2168    try:
2169        fig = make_subplots(
2170            rows=nrows,
2171            cols=ncols,
2172            specs=specs,
2173            shared_xaxes="all",
2174            shared_yaxes="all",
2175            row_titles=[] if facet_col_wrap else list(reversed(row_labels)),
2176            column_titles=[] if facet_col_wrap else col_labels,
2177            subplot_titles=subplot_labels if facet_col_wrap else [],
2178            horizontal_spacing=horizontal_spacing,
2179            vertical_spacing=vertical_spacing,
2180            row_heights=row_heights,
2181            column_widths=column_widths,
2182            start_cell="bottom-left",
2183        )
2184    except ValueError as e:
2185        _spacing_error_translator(e, "Horizontal", "facet_col_spacing")
2186        _spacing_error_translator(e, "Vertical", "facet_row_spacing")
2187
2188    # Remove explicit font size of row/col titles so template can take over
2189    for annot in fig.layout.annotations:
2190        annot.update(font=None)
2191
2192    return fig
2193