1import plotly.graph_objs as go 2import plotly.io as pio 3from collections import namedtuple, OrderedDict 4from ._special_inputs import IdentityMap, Constant, Range 5 6from _plotly_utils.basevalidators import ColorscaleValidator 7from plotly.colors import qualitative, sequential 8import math 9import pandas as pd 10import numpy as np 11 12from plotly.subplots import ( 13 make_subplots, 14 _set_trace_grid_reference, 15 _subplot_type_for_trace_type, 16) 17 18NO_COLOR = "px_no_color_constant" 19 20# Declare all supported attributes, across all plot types 21direct_attrables = ( 22 ["base", "x", "y", "z", "a", "b", "c", "r", "theta", "size", "x_start", "x_end"] 23 + ["hover_name", "text", "names", "values", "parents", "wide_cross"] 24 + ["ids", "error_x", "error_x_minus", "error_y", "error_y_minus", "error_z"] 25 + ["error_z_minus", "lat", "lon", "locations", "animation_group"] 26) 27array_attrables = ["dimensions", "custom_data", "hover_data", "path", "wide_variable"] 28group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] 29renameable_group_attrables = [ 30 "color", # renamed to marker.color or line.color in infer_config 31 "symbol", # renamed to marker.symbol in infer_config 32 "line_dash", # renamed to line.dash in infer_config 33] 34all_attrables = ( 35 direct_attrables + array_attrables + group_attrables + renameable_group_attrables 36) 37 38cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin] 39cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour] 40 41 42class PxDefaults(object): 43 __slots__ = [ 44 "template", 45 "width", 46 "height", 47 "color_discrete_sequence", 48 "color_discrete_map", 49 "color_continuous_scale", 50 "symbol_sequence", 51 "symbol_map", 52 "line_dash_sequence", 53 "line_dash_map", 54 "size_max", 55 "category_orders", 56 "labels", 57 ] 58 59 def __init__(self): 60 self.reset() 61 62 def reset(self): 63 self.template = None 64 self.width = None 65 self.height = None 66 self.color_discrete_sequence = None 67 self.color_discrete_map = {} 68 self.color_continuous_scale = None 69 self.symbol_sequence = None 70 self.symbol_map = {} 71 self.line_dash_sequence = None 72 self.line_dash_map = {} 73 self.size_max = 20 74 self.category_orders = {} 75 self.labels = {} 76 77 78defaults = PxDefaults() 79del PxDefaults 80 81 82MAPBOX_TOKEN = None 83 84 85def set_mapbox_access_token(token): 86 """ 87 Arguments: 88 token: A Mapbox token to be used in `plotly.express.scatter_mapbox` and \ 89 `plotly.express.line_mapbox` figures. See \ 90 https://docs.mapbox.com/help/how-mapbox-works/access-tokens/ for more details 91 """ 92 global MAPBOX_TOKEN 93 MAPBOX_TOKEN = token 94 95 96def get_trendline_results(fig): 97 """ 98 Extracts fit statistics for trendlines (when applied to figures generated with 99 the `trendline` argument set to `"ols"`). 100 101 Arguments: 102 fig: the output of a `plotly.express` charting call 103 Returns: 104 A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels` 105 results objects, along with columns identifying the subset of the data the 106 trendline was fit on. 107 """ 108 return fig._px_trendlines 109 110 111Mapping = namedtuple( 112 "Mapping", 113 [ 114 "show_in_trace_name", 115 "grouper", 116 "val_map", 117 "sequence", 118 "updater", 119 "variable", 120 "facet", 121 ], 122) 123TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"]) 124 125 126def get_label(args, column): 127 try: 128 return args["labels"][column] 129 except Exception: 130 return column 131 132 133def invert_label(args, column): 134 """Invert mapping. 135 Find key corresponding to value column in dict args["labels"]. 136 Returns `column` if the value does not exist. 137 """ 138 reversed_labels = {value: key for (key, value) in args["labels"].items()} 139 try: 140 return reversed_labels[column] 141 except Exception: 142 return column 143 144 145def _is_continuous(df, col_name): 146 return df[col_name].dtype.kind in "ifc" 147 148 149def get_decorated_label(args, column, role): 150 original_label = label = get_label(args, column) 151 if "histfunc" in args and ( 152 (role == "z") 153 or (role == "x" and "orientation" in args and args["orientation"] == "h") 154 or (role == "y" and "orientation" in args and args["orientation"] == "v") 155 ): 156 histfunc = args["histfunc"] or "count" 157 if histfunc != "count": 158 label = "%s of %s" % (histfunc, label) 159 else: 160 label = "count" 161 162 if "histnorm" in args and args["histnorm"] is not None: 163 if label == "count": 164 label = args["histnorm"] 165 else: 166 histnorm = args["histnorm"] 167 if histfunc == "sum": 168 if histnorm == "probability": 169 label = "%s of %s" % ("fraction", label) 170 elif histnorm == "percent": 171 label = "%s of %s" % (histnorm, label) 172 else: 173 label = "%s weighted by %s" % (histnorm, original_label) 174 elif histnorm == "probability": 175 label = "%s of sum of %s" % ("fraction", label) 176 elif histnorm == "percent": 177 label = "%s of sum of %s" % ("percent", label) 178 else: 179 label = "%s of %s" % (histnorm, label) 180 181 if "barnorm" in args and args["barnorm"] is not None: 182 label = "%s (normalized as %s)" % (label, args["barnorm"]) 183 184 return label 185 186 187def make_mapping(args, variable): 188 if variable == "line_group" or variable == "animation_frame": 189 return Mapping( 190 show_in_trace_name=False, 191 grouper=args[variable], 192 val_map={}, 193 sequence=[""], 194 variable=variable, 195 updater=(lambda trace, v: v), 196 facet=None, 197 ) 198 if variable == "facet_row" or variable == "facet_col": 199 letter = "x" if variable == "facet_col" else "y" 200 return Mapping( 201 show_in_trace_name=False, 202 variable=letter, 203 grouper=args[variable], 204 val_map={}, 205 sequence=[i for i in range(1, 1000)], 206 updater=(lambda trace, v: v), 207 facet="row" if variable == "facet_row" else "col", 208 ) 209 (parent, variable) = variable.split(".") 210 vprefix = variable 211 arg_name = variable 212 if variable == "color": 213 vprefix = "color_discrete" 214 if variable == "dash": 215 arg_name = "line_dash" 216 vprefix = "line_dash" 217 if args[vprefix + "_map"] == "identity": 218 val_map = IdentityMap() 219 else: 220 val_map = args[vprefix + "_map"].copy() 221 return Mapping( 222 show_in_trace_name=True, 223 variable=variable, 224 grouper=args[arg_name], 225 val_map=val_map, 226 sequence=args[vprefix + "_sequence"], 227 updater=lambda trace, v: trace.update({parent: {variable: v}}), 228 facet=None, 229 ) 230 231 232def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): 233 """Populates a dict with arguments to update trace 234 235 Parameters 236 ---------- 237 args : dict 238 args to be used for the trace 239 trace_spec : NamedTuple 240 which kind of trace to be used (has constructor, marginal etc. 241 attributes) 242 trace_data : pandas DataFrame 243 data 244 mapping_labels : dict 245 to be used for hovertemplate 246 sizeref : float 247 marker sizeref 248 249 Returns 250 ------- 251 trace_patch : dict 252 dict to be used to update trace 253 fit_results : dict 254 fit information to be used for trendlines 255 """ 256 if "line_close" in args and args["line_close"]: 257 trace_data = trace_data.append(trace_data.iloc[0]) 258 trace_patch = trace_spec.trace_patch.copy() or {} 259 fit_results = None 260 hover_header = "" 261 for attr_name in trace_spec.attrs: 262 attr_value = args[attr_name] 263 attr_label = get_decorated_label(args, attr_value, attr_name) 264 if attr_name == "dimensions": 265 dims = [ 266 (name, column) 267 for (name, column) in trace_data.iteritems() 268 if ((not attr_value) or (name in attr_value)) 269 and ( 270 trace_spec.constructor != go.Parcoords 271 or _is_continuous(args["data_frame"], name) 272 ) 273 and ( 274 trace_spec.constructor != go.Parcats 275 or (attr_value is not None and name in attr_value) 276 or len(args["data_frame"][name].unique()) 277 <= args["dimensions_max_cardinality"] 278 ) 279 ] 280 trace_patch["dimensions"] = [ 281 dict(label=get_label(args, name), values=column) 282 for (name, column) in dims 283 ] 284 if trace_spec.constructor == go.Splom: 285 for d in trace_patch["dimensions"]: 286 d["axis"] = dict(matches=True) 287 mapping_labels["%{xaxis.title.text}"] = "%{x}" 288 mapping_labels["%{yaxis.title.text}"] = "%{y}" 289 290 elif attr_value is not None: 291 if attr_name == "size": 292 if "marker" not in trace_patch: 293 trace_patch["marker"] = dict() 294 trace_patch["marker"]["size"] = trace_data[attr_value] 295 trace_patch["marker"]["sizemode"] = "area" 296 trace_patch["marker"]["sizeref"] = sizeref 297 mapping_labels[attr_label] = "%{marker.size}" 298 elif attr_name == "marginal_x": 299 if trace_spec.constructor == go.Histogram: 300 mapping_labels["count"] = "%{y}" 301 elif attr_name == "marginal_y": 302 if trace_spec.constructor == go.Histogram: 303 mapping_labels["count"] = "%{x}" 304 elif attr_name == "trendline": 305 if ( 306 attr_value in ["ols", "lowess"] 307 and args["x"] 308 and args["y"] 309 and len(trace_data[[args["x"], args["y"]]].dropna()) > 1 310 ): 311 import statsmodels.api as sm 312 313 # sorting is bad but trace_specs with "trendline" have no other attrs 314 sorted_trace_data = trace_data.sort_values(by=args["x"]) 315 y = sorted_trace_data[args["y"]].values 316 x = sorted_trace_data[args["x"]].values 317 318 if x.dtype.type == np.datetime64: 319 x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds 320 elif x.dtype.type == np.object_: 321 try: 322 x = x.astype(np.float64) 323 except ValueError: 324 raise ValueError( 325 "Could not convert value of 'x' ('%s') into a numeric type. " 326 "If 'x' contains stringified dates, please convert to a datetime column." 327 % args["x"] 328 ) 329 if y.dtype.type == np.object_: 330 try: 331 y = y.astype(np.float64) 332 except ValueError: 333 raise ValueError( 334 "Could not convert value of 'y' into a numeric type." 335 ) 336 337 # preserve original values of "x" in case they're dates 338 trace_patch["x"] = sorted_trace_data[args["x"]][ 339 np.logical_not(np.logical_or(np.isnan(y), np.isnan(x))) 340 ] 341 342 if attr_value == "lowess": 343 # missing ='drop' is the default value for lowess but not for OLS (None) 344 # we force it here in case statsmodels change their defaults 345 trendline = sm.nonparametric.lowess(y, x, missing="drop") 346 trace_patch["y"] = trendline[:, 1] 347 hover_header = "<b>LOWESS trendline</b><br><br>" 348 elif attr_value == "ols": 349 fit_results = sm.OLS( 350 y, sm.add_constant(x), missing="drop" 351 ).fit() 352 trace_patch["y"] = fit_results.predict() 353 hover_header = "<b>OLS trendline</b><br>" 354 if len(fit_results.params) == 2: 355 hover_header += "%s = %g * %s + %g<br>" % ( 356 args["y"], 357 fit_results.params[1], 358 args["x"], 359 fit_results.params[0], 360 ) 361 else: 362 hover_header += "%s = %g<br>" % ( 363 args["y"], 364 fit_results.params[0], 365 ) 366 hover_header += ( 367 "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared 368 ) 369 mapping_labels[get_label(args, args["x"])] = "%{x}" 370 mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>" 371 elif attr_name.startswith("error"): 372 error_xy = attr_name[:7] 373 arr = "arrayminus" if attr_name.endswith("minus") else "array" 374 if error_xy not in trace_patch: 375 trace_patch[error_xy] = {} 376 trace_patch[error_xy][arr] = trace_data[attr_value] 377 elif attr_name == "custom_data": 378 if len(attr_value) > 0: 379 # here we store a data frame in customdata, and it's serialized 380 # as a list of row lists, which is what we want 381 trace_patch["customdata"] = trace_data[attr_value] 382 elif attr_name == "hover_name": 383 if trace_spec.constructor not in [ 384 go.Histogram, 385 go.Histogram2d, 386 go.Histogram2dContour, 387 ]: 388 trace_patch["hovertext"] = trace_data[attr_value] 389 if hover_header == "": 390 hover_header = "<b>%{hovertext}</b><br><br>" 391 elif attr_name == "hover_data": 392 if trace_spec.constructor not in [ 393 go.Histogram, 394 go.Histogram2d, 395 go.Histogram2dContour, 396 ]: 397 hover_is_dict = isinstance(attr_value, dict) 398 customdata_cols = args.get("custom_data") or [] 399 for col in attr_value: 400 if hover_is_dict and not attr_value[col]: 401 continue 402 if col in [ 403 args.get("x", None), 404 args.get("y", None), 405 args.get("z", None), 406 args.get("base", None), 407 ]: 408 continue 409 try: 410 position = args["custom_data"].index(col) 411 except (ValueError, AttributeError, KeyError): 412 position = len(customdata_cols) 413 customdata_cols.append(col) 414 attr_label_col = get_decorated_label(args, col, None) 415 mapping_labels[attr_label_col] = "%%{customdata[%d]}" % ( 416 position 417 ) 418 419 if len(customdata_cols) > 0: 420 # here we store a data frame in customdata, and it's serialized 421 # as a list of row lists, which is what we want 422 trace_patch["customdata"] = trace_data[customdata_cols] 423 elif attr_name == "color": 424 if trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox]: 425 trace_patch["z"] = trace_data[attr_value] 426 trace_patch["coloraxis"] = "coloraxis1" 427 mapping_labels[attr_label] = "%{z}" 428 elif trace_spec.constructor in [ 429 go.Sunburst, 430 go.Treemap, 431 go.Pie, 432 go.Funnelarea, 433 ]: 434 if "marker" not in trace_patch: 435 trace_patch["marker"] = dict() 436 437 if args.get("color_is_continuous"): 438 trace_patch["marker"]["colors"] = trace_data[attr_value] 439 trace_patch["marker"]["coloraxis"] = "coloraxis1" 440 mapping_labels[attr_label] = "%{color}" 441 else: 442 trace_patch["marker"]["colors"] = [] 443 if args["color_discrete_map"] is not None: 444 mapping = args["color_discrete_map"].copy() 445 else: 446 mapping = {} 447 for cat in trace_data[attr_value]: 448 if mapping.get(cat) is None: 449 mapping[cat] = args["color_discrete_sequence"][ 450 len(mapping) % len(args["color_discrete_sequence"]) 451 ] 452 trace_patch["marker"]["colors"].append(mapping[cat]) 453 else: 454 colorable = "marker" 455 if trace_spec.constructor in [go.Parcats, go.Parcoords]: 456 colorable = "line" 457 if colorable not in trace_patch: 458 trace_patch[colorable] = dict() 459 trace_patch[colorable]["color"] = trace_data[attr_value] 460 trace_patch[colorable]["coloraxis"] = "coloraxis1" 461 mapping_labels[attr_label] = "%%{%s.color}" % colorable 462 elif attr_name == "animation_group": 463 trace_patch["ids"] = trace_data[attr_value] 464 elif attr_name == "locations": 465 trace_patch[attr_name] = trace_data[attr_value] 466 mapping_labels[attr_label] = "%{location}" 467 elif attr_name == "values": 468 trace_patch[attr_name] = trace_data[attr_value] 469 _label = "value" if attr_label == "values" else attr_label 470 mapping_labels[_label] = "%{value}" 471 elif attr_name == "parents": 472 trace_patch[attr_name] = trace_data[attr_value] 473 _label = "parent" if attr_label == "parents" else attr_label 474 mapping_labels[_label] = "%{parent}" 475 elif attr_name == "ids": 476 trace_patch[attr_name] = trace_data[attr_value] 477 _label = "id" if attr_label == "ids" else attr_label 478 mapping_labels[_label] = "%{id}" 479 elif attr_name == "names": 480 if trace_spec.constructor in [ 481 go.Sunburst, 482 go.Treemap, 483 go.Pie, 484 go.Funnelarea, 485 ]: 486 trace_patch["labels"] = trace_data[attr_value] 487 _label = "label" if attr_label == "names" else attr_label 488 mapping_labels[_label] = "%{label}" 489 else: 490 trace_patch[attr_name] = trace_data[attr_value] 491 else: 492 trace_patch[attr_name] = trace_data[attr_value] 493 mapping_labels[attr_label] = "%%{%s}" % attr_name 494 elif (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or ( 495 trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour] 496 and attr_name == "z" 497 ): 498 # ensure that stuff like "count" gets into the hoverlabel 499 mapping_labels[attr_label] = "%%{%s}" % attr_name 500 if trace_spec.constructor not in [go.Parcoords, go.Parcats]: 501 # Modify mapping_labels according to hover_data keys 502 # if hover_data is a dict 503 mapping_labels_copy = OrderedDict(mapping_labels) 504 if args["hover_data"] and isinstance(args["hover_data"], dict): 505 for k, v in mapping_labels.items(): 506 # We need to invert the mapping here 507 k_args = invert_label(args, k) 508 if k_args in args["hover_data"]: 509 formatter = args["hover_data"][k_args][0] 510 if formatter: 511 if isinstance(formatter, str): 512 mapping_labels_copy[k] = v.replace("}", "%s}" % formatter) 513 else: 514 _ = mapping_labels_copy.pop(k) 515 hover_lines = [k + "=" + v for k, v in mapping_labels_copy.items()] 516 trace_patch["hovertemplate"] = hover_header + "<br>".join(hover_lines) 517 trace_patch["hovertemplate"] += "<extra></extra>" 518 return trace_patch, fit_results 519 520 521def configure_axes(args, constructor, fig, orders): 522 configurators = { 523 go.Scatter3d: configure_3d_axes, 524 go.Scatterternary: configure_ternary_axes, 525 go.Scatterpolar: configure_polar_axes, 526 go.Scatterpolargl: configure_polar_axes, 527 go.Barpolar: configure_polar_axes, 528 go.Scattermapbox: configure_mapbox, 529 go.Choroplethmapbox: configure_mapbox, 530 go.Densitymapbox: configure_mapbox, 531 go.Scattergeo: configure_geo, 532 go.Choropleth: configure_geo, 533 } 534 for c in cartesians: 535 configurators[c] = configure_cartesian_axes 536 if constructor in configurators: 537 configurators[constructor](args, fig, orders) 538 539 540def set_cartesian_axis_opts(args, axis, letter, orders): 541 log_key = "log_" + letter 542 range_key = "range_" + letter 543 if log_key in args and args[log_key]: 544 axis["type"] = "log" 545 if range_key in args and args[range_key]: 546 axis["range"] = [math.log(r, 10) for r in args[range_key]] 547 elif range_key in args and args[range_key]: 548 axis["range"] = args[range_key] 549 550 if args[letter] in orders: 551 axis["categoryorder"] = "array" 552 axis["categoryarray"] = ( 553 orders[args[letter]] 554 if isinstance(axis, go.layout.XAxis) 555 else list(reversed(orders[args[letter]])) 556 ) 557 558 559def configure_cartesian_marginal_axes(args, fig, orders): 560 561 if "histogram" in [args["marginal_x"], args["marginal_y"]]: 562 fig.layout["barmode"] = "overlay" 563 564 nrows = len(fig._grid_ref) 565 ncols = len(fig._grid_ref[0]) 566 567 # Set y-axis titles and axis options in the left-most column 568 for yaxis in fig.select_yaxes(col=1): 569 set_cartesian_axis_opts(args, yaxis, "y", orders) 570 571 # Set x-axis titles and axis options in the bottom-most row 572 for xaxis in fig.select_xaxes(row=1): 573 set_cartesian_axis_opts(args, xaxis, "x", orders) 574 575 # Configure axis ticks on marginal subplots 576 if args["marginal_x"]: 577 fig.update_yaxes( 578 showticklabels=False, showline=False, ticks="", range=None, row=nrows 579 ) 580 if args["template"].layout.yaxis.showgrid is None: 581 fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows) 582 if args["template"].layout.xaxis.showgrid is None: 583 fig.update_xaxes(showgrid=True, row=nrows) 584 585 if args["marginal_y"]: 586 fig.update_xaxes( 587 showticklabels=False, showline=False, ticks="", range=None, col=ncols 588 ) 589 if args["template"].layout.xaxis.showgrid is None: 590 fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols) 591 if args["template"].layout.yaxis.showgrid is None: 592 fig.update_yaxes(showgrid=True, col=ncols) 593 594 # Add axis titles to non-marginal subplots 595 y_title = get_decorated_label(args, args["y"], "y") 596 if args["marginal_x"]: 597 fig.update_yaxes(title_text=y_title, row=1, col=1) 598 else: 599 for row in range(1, nrows + 1): 600 fig.update_yaxes(title_text=y_title, row=row, col=1) 601 602 x_title = get_decorated_label(args, args["x"], "x") 603 if args["marginal_y"]: 604 fig.update_xaxes(title_text=x_title, row=1, col=1) 605 else: 606 for col in range(1, ncols + 1): 607 fig.update_xaxes(title_text=x_title, row=1, col=col) 608 609 # Configure axis type across all x-axes 610 if "log_x" in args and args["log_x"]: 611 fig.update_xaxes(type="log") 612 613 # Configure axis type across all y-axes 614 if "log_y" in args and args["log_y"]: 615 fig.update_yaxes(type="log") 616 617 # Configure matching and axis type for marginal y-axes 618 matches_y = "y" + str(ncols + 1) 619 if args["marginal_x"]: 620 for row in range(2, nrows + 1, 2): 621 fig.update_yaxes(matches=matches_y, type=None, row=row) 622 623 if args["marginal_y"]: 624 for col in range(2, ncols + 1, 2): 625 fig.update_xaxes(matches="x2", type=None, col=col) 626 627 628def configure_cartesian_axes(args, fig, orders): 629 if ("marginal_x" in args and args["marginal_x"]) or ( 630 "marginal_y" in args and args["marginal_y"] 631 ): 632 configure_cartesian_marginal_axes(args, fig, orders) 633 return 634 635 # Set y-axis titles and axis options in the left-most column 636 y_title = get_decorated_label(args, args["y"], "y") 637 for yaxis in fig.select_yaxes(col=1): 638 yaxis.update(title_text=y_title) 639 set_cartesian_axis_opts(args, yaxis, "y", orders) 640 641 # Set x-axis titles and axis options in the bottom-most row 642 x_title = get_decorated_label(args, args["x"], "x") 643 for xaxis in fig.select_xaxes(row=1): 644 if "is_timeline" not in args: 645 xaxis.update(title_text=x_title) 646 set_cartesian_axis_opts(args, xaxis, "x", orders) 647 648 # Configure axis type across all x-axes 649 if "log_x" in args and args["log_x"]: 650 fig.update_xaxes(type="log") 651 652 # Configure axis type across all y-axes 653 if "log_y" in args and args["log_y"]: 654 fig.update_yaxes(type="log") 655 656 if "is_timeline" in args: 657 fig.update_xaxes(type="date") 658 659 660def configure_ternary_axes(args, fig, orders): 661 fig.update_ternaries( 662 aaxis=dict(title_text=get_label(args, args["a"])), 663 baxis=dict(title_text=get_label(args, args["b"])), 664 caxis=dict(title_text=get_label(args, args["c"])), 665 ) 666 667 668def configure_polar_axes(args, fig, orders): 669 patch = dict( 670 angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]), 671 radialaxis=dict(), 672 ) 673 674 for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]: 675 if args[var] in orders: 676 patch[axis]["categoryorder"] = "array" 677 patch[axis]["categoryarray"] = orders[args[var]] 678 679 radialaxis = patch["radialaxis"] 680 if args["log_r"]: 681 radialaxis["type"] = "log" 682 if args["range_r"]: 683 radialaxis["range"] = [math.log(x, 10) for x in args["range_r"]] 684 else: 685 if args["range_r"]: 686 radialaxis["range"] = args["range_r"] 687 688 if args["range_theta"]: 689 patch["sector"] = args["range_theta"] 690 fig.update_polars(patch) 691 692 693def configure_3d_axes(args, fig, orders): 694 patch = dict( 695 xaxis=dict(title_text=get_label(args, args["x"])), 696 yaxis=dict(title_text=get_label(args, args["y"])), 697 zaxis=dict(title_text=get_label(args, args["z"])), 698 ) 699 700 for letter in ["x", "y", "z"]: 701 axis = patch[letter + "axis"] 702 if args["log_" + letter]: 703 axis["type"] = "log" 704 if args["range_" + letter]: 705 axis["range"] = [math.log(x, 10) for x in args["range_" + letter]] 706 else: 707 if args["range_" + letter]: 708 axis["range"] = args["range_" + letter] 709 if args[letter] in orders: 710 axis["categoryorder"] = "array" 711 axis["categoryarray"] = orders[args[letter]] 712 fig.update_scenes(patch) 713 714 715def configure_mapbox(args, fig, orders): 716 center = args["center"] 717 if not center and "lat" in args and "lon" in args: 718 center = dict( 719 lat=args["data_frame"][args["lat"]].mean(), 720 lon=args["data_frame"][args["lon"]].mean(), 721 ) 722 fig.update_mapboxes( 723 accesstoken=MAPBOX_TOKEN, 724 center=center, 725 zoom=args["zoom"], 726 style=args["mapbox_style"], 727 ) 728 729 730def configure_geo(args, fig, orders): 731 fig.update_geos( 732 center=args["center"], 733 scope=args["scope"], 734 fitbounds=args["fitbounds"], 735 visible=args["basemap_visible"], 736 projection=dict(type=args["projection"]), 737 ) 738 739 740def configure_animation_controls(args, constructor, fig): 741 def frame_args(duration): 742 return { 743 "frame": {"duration": duration, "redraw": constructor != go.Scatter}, 744 "mode": "immediate", 745 "fromcurrent": True, 746 "transition": {"duration": duration, "easing": "linear"}, 747 } 748 749 if "animation_frame" in args and args["animation_frame"] and len(fig.frames) > 1: 750 fig.layout.updatemenus = [ 751 { 752 "buttons": [ 753 { 754 "args": [None, frame_args(500)], 755 "label": "▶", 756 "method": "animate", 757 }, 758 { 759 "args": [[None], frame_args(0)], 760 "label": "◼", 761 "method": "animate", 762 }, 763 ], 764 "direction": "left", 765 "pad": {"r": 10, "t": 70}, 766 "showactive": False, 767 "type": "buttons", 768 "x": 0.1, 769 "xanchor": "right", 770 "y": 0, 771 "yanchor": "top", 772 } 773 ] 774 fig.layout.sliders = [ 775 { 776 "active": 0, 777 "yanchor": "top", 778 "xanchor": "left", 779 "currentvalue": { 780 "prefix": get_label(args, args["animation_frame"]) + "=" 781 }, 782 "pad": {"b": 10, "t": 60}, 783 "len": 0.9, 784 "x": 0.1, 785 "y": 0, 786 "steps": [ 787 { 788 "args": [[f.name], frame_args(0)], 789 "label": f.name, 790 "method": "animate", 791 } 792 for f in fig.frames 793 ], 794 } 795 ] 796 797 798def make_trace_spec(args, constructor, attrs, trace_patch): 799 if constructor in [go.Scatter, go.Scatterpolar]: 800 if "render_mode" in args and ( 801 args["render_mode"] == "webgl" 802 or ( 803 args["render_mode"] == "auto" 804 and len(args["data_frame"]) > 1000 805 and args["animation_frame"] is None 806 ) 807 ): 808 if constructor == go.Scatter: 809 constructor = go.Scattergl 810 if "orientation" in trace_patch: 811 del trace_patch["orientation"] 812 else: 813 constructor = go.Scatterpolargl 814 # Create base trace specification 815 result = [TraceSpec(constructor, attrs, trace_patch, None)] 816 817 # Add marginal trace specifications 818 for letter in ["x", "y"]: 819 if "marginal_" + letter in args and args["marginal_" + letter]: 820 trace_spec = None 821 axis_map = dict( 822 xaxis="x1" if letter == "x" else "x2", 823 yaxis="y1" if letter == "y" else "y2", 824 ) 825 if args["marginal_" + letter] == "histogram": 826 trace_spec = TraceSpec( 827 constructor=go.Histogram, 828 attrs=[letter, "marginal_" + letter], 829 trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map), 830 marginal=letter, 831 ) 832 elif args["marginal_" + letter] == "violin": 833 trace_spec = TraceSpec( 834 constructor=go.Violin, 835 attrs=[letter, "hover_name", "hover_data"], 836 trace_patch=dict(scalegroup=letter), 837 marginal=letter, 838 ) 839 elif args["marginal_" + letter] == "box": 840 trace_spec = TraceSpec( 841 constructor=go.Box, 842 attrs=[letter, "hover_name", "hover_data"], 843 trace_patch=dict(notched=True), 844 marginal=letter, 845 ) 846 elif args["marginal_" + letter] == "rug": 847 symbols = {"x": "line-ns-open", "y": "line-ew-open"} 848 trace_spec = TraceSpec( 849 constructor=go.Box, 850 attrs=[letter, "hover_name", "hover_data"], 851 trace_patch=dict( 852 fillcolor="rgba(255,255,255,0)", 853 line={"color": "rgba(255,255,255,0)"}, 854 boxpoints="all", 855 jitter=0, 856 hoveron="points", 857 marker={"symbol": symbols[letter]}, 858 ), 859 marginal=letter, 860 ) 861 if "color" in attrs or "color" not in args: 862 if "marker" not in trace_spec.trace_patch: 863 trace_spec.trace_patch["marker"] = dict() 864 first_default_color = args["color_continuous_scale"][0] 865 trace_spec.trace_patch["marker"]["color"] = first_default_color 866 result.append(trace_spec) 867 868 # Add trendline trace specifications 869 if "trendline" in args and args["trendline"]: 870 trace_spec = TraceSpec( 871 constructor=go.Scattergl if constructor == go.Scattergl else go.Scatter, 872 attrs=["trendline"], 873 trace_patch=dict(mode="lines"), 874 marginal=None, 875 ) 876 if args["trendline_color_override"]: 877 trace_spec.trace_patch["line"] = dict( 878 color=args["trendline_color_override"] 879 ) 880 result.append(trace_spec) 881 return result 882 883 884def one_group(x): 885 return "" 886 887 888def apply_default_cascade(args): 889 # first we apply px.defaults to unspecified args 890 891 for param in defaults.__slots__: 892 if param in args and args[param] is None: 893 args[param] = getattr(defaults, param) 894 895 # load the default template if set, otherwise "plotly" 896 if args["template"] is None: 897 if pio.templates.default is not None: 898 args["template"] = pio.templates.default 899 else: 900 args["template"] = "plotly" 901 902 try: 903 # retrieve the actual template if we were given a name 904 args["template"] = pio.templates[args["template"]] 905 except Exception: 906 # otherwise try to build a real template 907 args["template"] = go.layout.Template(args["template"]) 908 909 # if colors not set explicitly or in px.defaults, defer to a template 910 # if the template doesn't have one, we set some final fallback defaults 911 if "color_continuous_scale" in args: 912 if ( 913 args["color_continuous_scale"] is None 914 and args["template"].layout.colorscale.sequential 915 ): 916 args["color_continuous_scale"] = [ 917 x[1] for x in args["template"].layout.colorscale.sequential 918 ] 919 if args["color_continuous_scale"] is None: 920 args["color_continuous_scale"] = sequential.Viridis 921 922 if "color_discrete_sequence" in args: 923 if args["color_discrete_sequence"] is None and args["template"].layout.colorway: 924 args["color_discrete_sequence"] = args["template"].layout.colorway 925 if args["color_discrete_sequence"] is None: 926 args["color_discrete_sequence"] = qualitative.D3 927 928 # if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults, 929 # see if we can defer to template. If not, set reasonable defaults 930 if "symbol_sequence" in args: 931 if args["symbol_sequence"] is None and args["template"].data.scatter: 932 args["symbol_sequence"] = [ 933 scatter.marker.symbol for scatter in args["template"].data.scatter 934 ] 935 if not args["symbol_sequence"] or not any(args["symbol_sequence"]): 936 args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"] 937 938 if "line_dash_sequence" in args: 939 if args["line_dash_sequence"] is None and args["template"].data.scatter: 940 args["line_dash_sequence"] = [ 941 scatter.line.dash for scatter in args["template"].data.scatter 942 ] 943 if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]): 944 args["line_dash_sequence"] = [ 945 "solid", 946 "dot", 947 "dash", 948 "longdash", 949 "dashdot", 950 "longdashdot", 951 ] 952 953 954def _check_name_not_reserved(field_name, reserved_names): 955 if field_name not in reserved_names: 956 return field_name 957 else: 958 raise NameError( 959 "A name conflict was encountered for argument '%s'. " 960 "A column or index with name '%s' is ambiguous." % (field_name, field_name) 961 ) 962 963 964def _get_reserved_col_names(args): 965 """ 966 This function builds a list of columns of the data_frame argument used 967 as arguments, either as str/int arguments or given as columns 968 (pandas series type). 969 """ 970 df = args["data_frame"] 971 reserved_names = set() 972 for field in args: 973 if field not in all_attrables: 974 continue 975 names = args[field] if field in array_attrables else [args[field]] 976 if names is None: 977 continue 978 for arg in names: 979 if arg is None: 980 continue 981 elif isinstance(arg, str): # no need to add ints since kw arg are not ints 982 reserved_names.add(arg) 983 elif isinstance(arg, pd.Series): 984 arg_name = arg.name 985 if arg_name and hasattr(df, arg_name): 986 in_df = arg is df[arg_name] 987 if in_df: 988 reserved_names.add(arg_name) 989 elif arg is df.index and arg.name is not None: 990 reserved_names.add(arg.name) 991 992 return reserved_names 993 994 995def _is_col_list(df_input, arg): 996 """Returns True if arg looks like it's a list of columns or references to columns 997 in df_input, and False otherwise (in which case it's assumed to be a single column 998 or reference to a column). 999 """ 1000 if arg is None or isinstance(arg, str) or isinstance(arg, int): 1001 return False 1002 if isinstance(arg, pd.MultiIndex): 1003 return False # just to keep existing behaviour for now 1004 try: 1005 iter(arg) 1006 except TypeError: 1007 return False # not iterable 1008 for c in arg: 1009 if isinstance(c, str) or isinstance(c, int): 1010 if df_input is None or c not in df_input.columns: 1011 return False 1012 else: 1013 try: 1014 iter(c) 1015 except TypeError: 1016 return False # not iterable 1017 return True 1018 1019 1020def _isinstance_listlike(x): 1021 """Returns True if x is an iterable which can be transformed into a pandas Series, 1022 False for the other types of possible values of a `hover_data` dict. 1023 A tuple of length 2 is a special case corresponding to a (format, data) tuple. 1024 """ 1025 if ( 1026 isinstance(x, str) 1027 or (isinstance(x, tuple) and len(x) == 2) 1028 or isinstance(x, bool) 1029 or x is None 1030 ): 1031 return False 1032 else: 1033 return True 1034 1035 1036def _escape_col_name(df_input, col_name, extra): 1037 while df_input is not None and (col_name in df_input.columns or col_name in extra): 1038 col_name = "_" + col_name 1039 return col_name 1040 1041 1042def to_unindexed_series(x): 1043 """ 1044 assuming x is list-like or even an existing pd.Series, return a new pd.Series with 1045 no index, without extracting the data from an existing Series via numpy, which 1046 seems to mangle datetime columns. Stripping the index from existing pd.Series is 1047 required to get things to match up right in the new DataFrame we're building 1048 """ 1049 return pd.Series(x).reset_index(drop=True) 1050 1051 1052def process_args_into_dataframe(args, wide_mode, var_name, value_name): 1053 """ 1054 After this function runs, the `all_attrables` keys of `args` all contain only 1055 references to columns of `df_output`. This function handles the extraction of data 1056 from `args["attrable"]` and column-name-generation as appropriate, and adds the 1057 data to `df_output` and then replaces `args["attrable"]` with the appropriate 1058 reference. 1059 """ 1060 1061 df_input = args["data_frame"] 1062 df_provided = df_input is not None 1063 1064 df_output = pd.DataFrame() 1065 constants = dict() 1066 ranges = list() 1067 wide_id_vars = set() 1068 reserved_names = _get_reserved_col_names(args) if df_provided else set() 1069 1070 # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords 1071 if "dimensions" in args and args["dimensions"] is None: 1072 if not df_provided: 1073 raise ValueError( 1074 "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument." 1075 ) 1076 else: 1077 df_output[df_input.columns] = df_input[df_input.columns] 1078 1079 # hover_data is a dict 1080 hover_data_is_dict = ( 1081 "hover_data" in args 1082 and args["hover_data"] 1083 and isinstance(args["hover_data"], dict) 1084 ) 1085 # If dict, convert all values of hover_data to tuples to simplify processing 1086 if hover_data_is_dict: 1087 for k in args["hover_data"]: 1088 if _isinstance_listlike(args["hover_data"][k]): 1089 args["hover_data"][k] = (True, args["hover_data"][k]) 1090 if not isinstance(args["hover_data"][k], tuple): 1091 args["hover_data"][k] = (args["hover_data"][k], None) 1092 if df_provided and args["hover_data"][k][1] is not None and k in df_input: 1093 raise ValueError( 1094 "Ambiguous input: values for '%s' appear both in hover_data and data_frame" 1095 % k 1096 ) 1097 # Loop over possible arguments 1098 for field_name in all_attrables: 1099 # Massaging variables 1100 argument_list = ( 1101 [args.get(field_name)] 1102 if field_name not in array_attrables 1103 else args.get(field_name) 1104 ) 1105 # argument not specified, continue 1106 if argument_list is None or argument_list is [None]: 1107 continue 1108 # Argument name: field_name if the argument is not a list 1109 # Else we give names like ["hover_data_0, hover_data_1"] etc. 1110 field_list = ( 1111 [field_name] 1112 if field_name not in array_attrables 1113 else [field_name + "_" + str(i) for i in range(len(argument_list))] 1114 ) 1115 # argument_list and field_list ready, iterate over them 1116 # Core of the loop starts here 1117 for i, (argument, field) in enumerate(zip(argument_list, field_list)): 1118 length = len(df_output) 1119 if argument is None: 1120 continue 1121 col_name = None 1122 # Case of multiindex 1123 if isinstance(argument, pd.MultiIndex): 1124 raise TypeError( 1125 "Argument '%s' is a pandas MultiIndex. " 1126 "pandas MultiIndex is not supported by plotly express " 1127 "at the moment." % field 1128 ) 1129 # ----------------- argument is a special value ---------------------- 1130 if isinstance(argument, Constant) or isinstance(argument, Range): 1131 col_name = _check_name_not_reserved( 1132 str(argument.label) if argument.label is not None else field, 1133 reserved_names, 1134 ) 1135 if isinstance(argument, Constant): 1136 constants[col_name] = argument.value 1137 else: 1138 ranges.append(col_name) 1139 # ----------------- argument is likely a col name ---------------------- 1140 elif isinstance(argument, str) or not hasattr(argument, "__len__"): 1141 if ( 1142 field_name == "hover_data" 1143 and hover_data_is_dict 1144 and args["hover_data"][str(argument)][1] is not None 1145 ): 1146 # hover_data has onboard data 1147 # previously-checked to have no name-conflict with data_frame 1148 col_name = str(argument) 1149 real_argument = args["hover_data"][col_name][1] 1150 1151 if length and len(real_argument) != length: 1152 raise ValueError( 1153 "All arguments should have the same length. " 1154 "The length of hover_data key `%s` is %d, whereas the " 1155 "length of previously-processed arguments %s is %d" 1156 % ( 1157 argument, 1158 len(real_argument), 1159 str(list(df_output.columns)), 1160 length, 1161 ) 1162 ) 1163 df_output[col_name] = to_unindexed_series(real_argument) 1164 elif not df_provided: 1165 raise ValueError( 1166 "String or int arguments are only possible when a " 1167 "DataFrame or an array is provided in the `data_frame` " 1168 "argument. No DataFrame was provided, but argument " 1169 "'%s' is of type str or int." % field 1170 ) 1171 # Check validity of column name 1172 elif argument not in df_input.columns: 1173 if wide_mode and argument in (value_name, var_name): 1174 continue 1175 else: 1176 err_msg = ( 1177 "Value of '%s' is not the name of a column in 'data_frame'. " 1178 "Expected one of %s but received: %s" 1179 % (field, str(list(df_input.columns)), argument) 1180 ) 1181 if argument == "index": 1182 err_msg += "\n To use the index, pass it in directly as `df.index`." 1183 raise ValueError(err_msg) 1184 elif length and len(df_input[argument]) != length: 1185 raise ValueError( 1186 "All arguments should have the same length. " 1187 "The length of column argument `df[%s]` is %d, whereas the " 1188 "length of previously-processed arguments %s is %d" 1189 % ( 1190 field, 1191 len(df_input[argument]), 1192 str(list(df_output.columns)), 1193 length, 1194 ) 1195 ) 1196 else: 1197 col_name = str(argument) 1198 df_output[col_name] = to_unindexed_series(df_input[argument]) 1199 # ----------------- argument is likely a column / array / list.... ------- 1200 else: 1201 if df_provided and hasattr(argument, "name"): 1202 if argument is df_input.index: 1203 if argument.name is None or argument.name in df_input: 1204 col_name = "index" 1205 else: 1206 col_name = argument.name 1207 col_name = _escape_col_name( 1208 df_input, col_name, [var_name, value_name] 1209 ) 1210 else: 1211 if ( 1212 argument.name is not None 1213 and argument.name in df_input 1214 and argument is df_input[argument.name] 1215 ): 1216 col_name = argument.name 1217 if col_name is None: # numpy array, list... 1218 col_name = _check_name_not_reserved(field, reserved_names) 1219 1220 if length and len(argument) != length: 1221 raise ValueError( 1222 "All arguments should have the same length. " 1223 "The length of argument `%s` is %d, whereas the " 1224 "length of previously-processed arguments %s is %d" 1225 % (field, len(argument), str(list(df_output.columns)), length) 1226 ) 1227 df_output[str(col_name)] = to_unindexed_series(argument) 1228 1229 # Finally, update argument with column name now that column exists 1230 assert col_name is not None, ( 1231 "Data-frame processing failure, likely due to a internal bug. " 1232 "Please report this to " 1233 "https://github.com/plotly/plotly.py/issues/new and we will try to " 1234 "replicate and fix it." 1235 ) 1236 if field_name not in array_attrables: 1237 args[field_name] = str(col_name) 1238 elif isinstance(args[field_name], dict): 1239 pass 1240 else: 1241 args[field_name][i] = str(col_name) 1242 if field_name != "wide_variable": 1243 wide_id_vars.add(str(col_name)) 1244 1245 for col_name in ranges: 1246 df_output[col_name] = range(len(df_output)) 1247 1248 for col_name in constants: 1249 df_output[col_name] = constants[col_name] 1250 1251 return df_output, wide_id_vars 1252 1253 1254def build_dataframe(args, constructor): 1255 """ 1256 Constructs a dataframe and modifies `args` in-place. 1257 1258 The argument values in `args` can be either strings corresponding to 1259 existing columns of a dataframe, or data arrays (lists, numpy arrays, 1260 pandas columns, series). 1261 1262 Parameters 1263 ---------- 1264 args : OrderedDict 1265 arguments passed to the px function and subsequently modified 1266 constructor : graph_object trace class 1267 the trace type selected for this figure 1268 """ 1269 1270 # make copies of all the fields via dict() and list() 1271 for field in args: 1272 if field in array_attrables and args[field] is not None: 1273 args[field] = ( 1274 dict(args[field]) 1275 if isinstance(args[field], dict) 1276 else list(args[field]) 1277 ) 1278 1279 # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) 1280 df_provided = args["data_frame"] is not None 1281 if df_provided and not isinstance(args["data_frame"], pd.DataFrame): 1282 args["data_frame"] = pd.DataFrame(args["data_frame"]) 1283 df_input = args["data_frame"] 1284 1285 # now we handle special cases like wide-mode or x-xor-y specification 1286 # by rearranging args to tee things up for process_args_into_dataframe to work 1287 no_x = args.get("x", None) is None 1288 no_y = args.get("y", None) is None 1289 wide_x = False if no_x else _is_col_list(df_input, args["x"]) 1290 wide_y = False if no_y else _is_col_list(df_input, args["y"]) 1291 1292 wide_mode = False 1293 var_name = None # will likely be "variable" in wide_mode 1294 wide_cross_name = None # will likely be "index" in wide_mode 1295 value_name = None # will likely be "value" in wide_mode 1296 hist2d_types = [go.Histogram2d, go.Histogram2dContour] 1297 if constructor in cartesians: 1298 if wide_x and wide_y: 1299 raise ValueError( 1300 "Cannot accept list of column references or list of columns for both `x` and `y`." 1301 ) 1302 if df_provided and no_x and no_y: 1303 wide_mode = True 1304 if isinstance(df_input.columns, pd.MultiIndex): 1305 raise TypeError( 1306 "Data frame columns is a pandas MultiIndex. " 1307 "pandas MultiIndex is not supported by plotly express " 1308 "at the moment." 1309 ) 1310 args["wide_variable"] = list(df_input.columns) 1311 var_name = df_input.columns.name 1312 if var_name in [None, "value", "index"] or var_name in df_input: 1313 var_name = "variable" 1314 if constructor == go.Funnel: 1315 wide_orientation = args.get("orientation", None) or "h" 1316 else: 1317 wide_orientation = args.get("orientation", None) or "v" 1318 args["orientation"] = wide_orientation 1319 args["wide_cross"] = None 1320 elif wide_x != wide_y: 1321 wide_mode = True 1322 args["wide_variable"] = args["y"] if wide_y else args["x"] 1323 if df_provided and args["wide_variable"] is df_input.columns: 1324 var_name = df_input.columns.name 1325 if isinstance(args["wide_variable"], pd.Index): 1326 args["wide_variable"] = list(args["wide_variable"]) 1327 if var_name in [None, "value", "index"] or ( 1328 df_provided and var_name in df_input 1329 ): 1330 var_name = "variable" 1331 if constructor == go.Histogram: 1332 wide_orientation = "v" if wide_x else "h" 1333 else: 1334 wide_orientation = "v" if wide_y else "h" 1335 args["y" if wide_y else "x"] = None 1336 args["wide_cross"] = None 1337 if not no_x and not no_y: 1338 wide_cross_name = "__x__" if wide_y else "__y__" 1339 1340 if wide_mode: 1341 value_name = _escape_col_name(df_input, "value", []) 1342 var_name = _escape_col_name(df_input, var_name, []) 1343 1344 missing_bar_dim = None 1345 if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types: 1346 if not wide_mode and (no_x != no_y): 1347 for ax in ["x", "y"]: 1348 if args.get(ax, None) is None: 1349 args[ax] = df_input.index if df_provided else Range() 1350 if constructor == go.Bar: 1351 missing_bar_dim = ax 1352 else: 1353 if args["orientation"] is None: 1354 args["orientation"] = "v" if ax == "x" else "h" 1355 if wide_mode and wide_cross_name is None: 1356 if no_x != no_y and args["orientation"] is None: 1357 args["orientation"] = "v" if no_x else "h" 1358 if df_provided: 1359 if isinstance(df_input.index, pd.MultiIndex): 1360 raise TypeError( 1361 "Data frame index is a pandas MultiIndex. " 1362 "pandas MultiIndex is not supported by plotly express " 1363 "at the moment." 1364 ) 1365 args["wide_cross"] = df_input.index 1366 else: 1367 args["wide_cross"] = Range( 1368 label=_escape_col_name(df_input, "index", [var_name, value_name]) 1369 ) 1370 1371 no_color = False 1372 if type(args.get("color", None)) == str and args["color"] == NO_COLOR: 1373 no_color = True 1374 args["color"] = None 1375 # now that things have been prepped, we do the systematic rewriting of `args` 1376 1377 df_output, wide_id_vars = process_args_into_dataframe( 1378 args, wide_mode, var_name, value_name 1379 ) 1380 1381 # now that `df_output` exists and `args` contains only references, we complete 1382 # the special-case and wide-mode handling by further rewriting args and/or mutating 1383 # df_output 1384 1385 count_name = _escape_col_name(df_output, "count", [var_name, value_name]) 1386 if not wide_mode and missing_bar_dim and constructor == go.Bar: 1387 # now that we've populated df_output, we check to see if the non-missing 1388 # dimension is categorical: if so, then setting the missing dimension to a 1389 # constant 1 is a less-insane thing to do than setting it to the index by 1390 # default and we let the normal auto-orientation-code do its thing later 1391 other_dim = "x" if missing_bar_dim == "y" else "y" 1392 if not _is_continuous(df_output, args[other_dim]): 1393 args[missing_bar_dim] = count_name 1394 df_output[count_name] = 1 1395 else: 1396 # on the other hand, if the non-missing dimension is continuous, then we 1397 # can use this information to override the normal auto-orientation code 1398 if args["orientation"] is None: 1399 args["orientation"] = "v" if missing_bar_dim == "x" else "h" 1400 1401 if constructor in hist2d_types: 1402 del args["orientation"] 1403 1404 if wide_mode: 1405 # at this point, `df_output` is semi-long/semi-wide, but we know which columns 1406 # are which, so we melt it and reassign `args` to refer to the newly-tidy 1407 # columns, keeping track of various names and manglings set up above 1408 wide_value_vars = [c for c in args["wide_variable"] if c not in wide_id_vars] 1409 del args["wide_variable"] 1410 if wide_cross_name == "__x__": 1411 wide_cross_name = args["x"] 1412 elif wide_cross_name == "__y__": 1413 wide_cross_name = args["y"] 1414 else: 1415 wide_cross_name = args["wide_cross"] 1416 del args["wide_cross"] 1417 dtype = None 1418 for v in wide_value_vars: 1419 v_dtype = df_output[v].dtype.kind 1420 v_dtype = "number" if v_dtype in ["i", "f", "u"] else v_dtype 1421 if dtype is None: 1422 dtype = v_dtype 1423 elif dtype != v_dtype: 1424 raise ValueError( 1425 "Plotly Express cannot process wide-form data with columns of different type." 1426 ) 1427 df_output = df_output.melt( 1428 id_vars=wide_id_vars, 1429 value_vars=wide_value_vars, 1430 var_name=var_name, 1431 value_name=value_name, 1432 ) 1433 assert len(df_output.columns) == len(set(df_output.columns)), ( 1434 "Wide-mode name-inference failure, likely due to a internal bug. " 1435 "Please report this to " 1436 "https://github.com/plotly/plotly.py/issues/new and we will try to " 1437 "replicate and fix it." 1438 ) 1439 df_output[var_name] = df_output[var_name].astype(str) 1440 orient_v = wide_orientation == "v" 1441 1442 if constructor in [go.Scatter, go.Funnel] + hist2d_types: 1443 args["x" if orient_v else "y"] = wide_cross_name 1444 args["y" if orient_v else "x"] = value_name 1445 if constructor != go.Histogram2d: 1446 args["color"] = args["color"] or var_name 1447 if "line_group" in args: 1448 args["line_group"] = args["line_group"] or var_name 1449 if constructor == go.Bar: 1450 if _is_continuous(df_output, value_name): 1451 args["x" if orient_v else "y"] = wide_cross_name 1452 args["y" if orient_v else "x"] = value_name 1453 args["color"] = args["color"] or var_name 1454 else: 1455 args["x" if orient_v else "y"] = value_name 1456 args["y" if orient_v else "x"] = count_name 1457 df_output[count_name] = 1 1458 args["color"] = args["color"] or var_name 1459 if constructor in [go.Violin, go.Box]: 1460 args["x" if orient_v else "y"] = wide_cross_name or var_name 1461 args["y" if orient_v else "x"] = value_name 1462 if constructor == go.Histogram: 1463 args["x" if orient_v else "y"] = value_name 1464 args["y" if orient_v else "x"] = wide_cross_name 1465 args["color"] = args["color"] or var_name 1466 if no_color: 1467 args["color"] = None 1468 args["data_frame"] = df_output 1469 return args 1470 1471 1472def _check_dataframe_all_leaves(df): 1473 df_sorted = df.sort_values(by=list(df.columns)) 1474 null_mask = df_sorted.isnull() 1475 df_sorted = df_sorted.astype(str) 1476 null_indices = np.nonzero(null_mask.any(axis=1).values)[0] 1477 for null_row_index in null_indices: 1478 row = null_mask.iloc[null_row_index] 1479 i = np.nonzero(row.values)[0][0] 1480 if not row[i:].all(): 1481 raise ValueError( 1482 "None entries cannot have not-None children", 1483 df_sorted.iloc[null_row_index], 1484 ) 1485 df_sorted[null_mask] = "" 1486 row_strings = list(df_sorted.apply(lambda x: "".join(x), axis=1)) 1487 for i, row in enumerate(row_strings[:-1]): 1488 if row_strings[i + 1] in row and (i + 1) in null_indices: 1489 raise ValueError( 1490 "Non-leaves rows are not permitted in the dataframe \n", 1491 df_sorted.iloc[i + 1], 1492 "is not a leaf.", 1493 ) 1494 1495 1496def process_dataframe_hierarchy(args): 1497 """ 1498 Build dataframe for sunburst or treemap when the path argument is provided. 1499 """ 1500 df = args["data_frame"] 1501 path = args["path"][::-1] 1502 _check_dataframe_all_leaves(df[path[::-1]]) 1503 discrete_color = False 1504 1505 new_path = [] 1506 for col_name in path: 1507 new_col_name = col_name + "_path_copy" 1508 new_path.append(new_col_name) 1509 df[new_col_name] = df[col_name] 1510 path = new_path 1511 # ------------ Define aggregation functions -------------------------------- 1512 1513 def aggfunc_discrete(x): 1514 uniques = x.unique() 1515 if len(uniques) == 1: 1516 return uniques[0] 1517 else: 1518 return "(?)" 1519 1520 agg_f = {} 1521 aggfunc_color = None 1522 if args["values"]: 1523 try: 1524 df[args["values"]] = pd.to_numeric(df[args["values"]]) 1525 except ValueError: 1526 raise ValueError( 1527 "Column `%s` of `df` could not be converted to a numerical data type." 1528 % args["values"] 1529 ) 1530 1531 if args["color"]: 1532 if args["color"] == args["values"]: 1533 new_value_col_name = args["values"] + "_sum" 1534 df[new_value_col_name] = df[args["values"]] 1535 args["values"] = new_value_col_name 1536 count_colname = args["values"] 1537 else: 1538 # we need a count column for the first groupby and the weighted mean of color 1539 # trick to be sure the col name is unused: take the sum of existing names 1540 count_colname = ( 1541 "count" 1542 if "count" not in df.columns 1543 else "".join([str(el) for el in list(df.columns)]) 1544 ) 1545 # we can modify df because it's a copy of the px argument 1546 df[count_colname] = 1 1547 args["values"] = count_colname 1548 agg_f[count_colname] = "sum" 1549 1550 if args["color"]: 1551 if not _is_continuous(df, args["color"]): 1552 aggfunc_color = aggfunc_discrete 1553 discrete_color = True 1554 else: 1555 1556 def aggfunc_continuous(x): 1557 return np.average(x, weights=df.loc[x.index, count_colname]) 1558 1559 aggfunc_color = aggfunc_continuous 1560 agg_f[args["color"]] = aggfunc_color 1561 1562 # Other columns (for color, hover_data, custom_data etc.) 1563 cols = list(set(df.columns).difference(path)) 1564 for col in cols: # for hover_data, custom_data etc. 1565 if col not in agg_f: 1566 agg_f[col] = aggfunc_discrete 1567 # Avoid collisions with reserved names - columns in the path have been copied already 1568 cols = list(set(cols) - set(["labels", "parent", "id"])) 1569 # ---------------------------------------------------------------------------- 1570 df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) 1571 # Set column type here (useful for continuous vs discrete colorscale) 1572 for col in cols: 1573 df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) 1574 for i, level in enumerate(path): 1575 df_tree = pd.DataFrame(columns=df_all_trees.columns) 1576 dfg = df.groupby(path[i:]).agg(agg_f) 1577 dfg = dfg.reset_index() 1578 # Path label massaging 1579 df_tree["labels"] = dfg[level].copy().astype(str) 1580 df_tree["parent"] = "" 1581 df_tree["id"] = dfg[level].copy().astype(str) 1582 if i < len(path) - 1: 1583 j = i + 1 1584 while j < len(path): 1585 df_tree["parent"] = ( 1586 dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"] 1587 ) 1588 df_tree["id"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["id"] 1589 j += 1 1590 1591 df_tree["parent"] = df_tree["parent"].str.rstrip("/") 1592 if cols: 1593 df_tree[cols] = dfg[cols] 1594 df_all_trees = df_all_trees.append(df_tree, ignore_index=True) 1595 1596 # we want to make sure than (?) is the first color of the sequence 1597 if args["color"] and discrete_color: 1598 sort_col_name = "sort_color_if_discrete_color" 1599 while sort_col_name in df_all_trees.columns: 1600 sort_col_name += "0" 1601 df_all_trees[sort_col_name] = df[args["color"]].astype(str) 1602 df_all_trees = df_all_trees.sort_values(by=sort_col_name) 1603 1604 # Now modify arguments 1605 args["data_frame"] = df_all_trees 1606 args["path"] = None 1607 args["ids"] = "id" 1608 args["names"] = "labels" 1609 args["parents"] = "parent" 1610 if args["color"]: 1611 if not args["hover_data"]: 1612 args["hover_data"] = [args["color"]] 1613 elif isinstance(args["hover_data"], dict): 1614 if not args["hover_data"].get(args["color"]): 1615 args["hover_data"][args["color"]] = (True, None) 1616 else: 1617 args["hover_data"].append(args["color"]) 1618 return args 1619 1620 1621def process_dataframe_timeline(args): 1622 """ 1623 Massage input for bar traces for px.timeline() 1624 """ 1625 args["is_timeline"] = True 1626 if args["x_start"] is None or args["x_end"] is None: 1627 raise ValueError("Both x_start and x_end are required") 1628 1629 try: 1630 x_start = pd.to_datetime(args["data_frame"][args["x_start"]]) 1631 x_end = pd.to_datetime(args["data_frame"][args["x_end"]]) 1632 except (ValueError, TypeError): 1633 raise TypeError( 1634 "Both x_start and x_end must refer to data convertible to datetimes." 1635 ) 1636 1637 # note that we are not adding any columns to the data frame here, so no risk of overwrite 1638 args["data_frame"][args["x_end"]] = (x_end - x_start).astype("timedelta64[ms]") 1639 args["x"] = args["x_end"] 1640 del args["x_end"] 1641 args["base"] = args["x_start"] 1642 del args["x_start"] 1643 return args 1644 1645 1646def infer_config(args, constructor, trace_patch, layout_patch): 1647 attrs = [k for k in direct_attrables + array_attrables if k in args] 1648 grouped_attrs = [] 1649 1650 # Compute sizeref 1651 sizeref = 0 1652 if "size" in args and args["size"]: 1653 sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2 1654 1655 # Compute color attributes and grouping attributes 1656 if "color" in args: 1657 if "color_continuous_scale" in args: 1658 if "color_discrete_sequence" not in args: 1659 attrs.append("color") 1660 else: 1661 if args["color"] and _is_continuous(args["data_frame"], args["color"]): 1662 attrs.append("color") 1663 args["color_is_continuous"] = True 1664 elif constructor in [go.Sunburst, go.Treemap]: 1665 attrs.append("color") 1666 args["color_is_continuous"] = False 1667 else: 1668 grouped_attrs.append("marker.color") 1669 elif "line_group" in args or constructor == go.Histogram2dContour: 1670 grouped_attrs.append("line.color") 1671 elif constructor in [go.Pie, go.Funnelarea]: 1672 attrs.append("color") 1673 if args["color"]: 1674 if args["hover_data"] is None: 1675 args["hover_data"] = [] 1676 args["hover_data"].append(args["color"]) 1677 else: 1678 grouped_attrs.append("marker.color") 1679 1680 show_colorbar = bool( 1681 "color" in attrs 1682 and args["color"] 1683 and constructor not in [go.Pie, go.Funnelarea] 1684 and ( 1685 constructor not in [go.Treemap, go.Sunburst] 1686 or args.get("color_is_continuous") 1687 ) 1688 ) 1689 else: 1690 show_colorbar = False 1691 1692 # Compute line_dash grouping attribute 1693 if "line_dash" in args: 1694 grouped_attrs.append("line.dash") 1695 1696 # Compute symbol grouping attribute 1697 if "symbol" in args: 1698 grouped_attrs.append("marker.symbol") 1699 1700 if "orientation" in args: 1701 has_x = args["x"] is not None 1702 has_y = args["y"] is not None 1703 if args["orientation"] is None: 1704 if constructor in [go.Histogram, go.Scatter]: 1705 if has_y and not has_x: 1706 args["orientation"] = "h" 1707 elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]: 1708 if has_x and not has_y: 1709 args["orientation"] = "h" 1710 1711 if args["orientation"] is None and has_x and has_y: 1712 x_is_continuous = _is_continuous(args["data_frame"], args["x"]) 1713 y_is_continuous = _is_continuous(args["data_frame"], args["y"]) 1714 if x_is_continuous and not y_is_continuous: 1715 args["orientation"] = "h" 1716 if y_is_continuous and not x_is_continuous: 1717 args["orientation"] = "v" 1718 1719 if args["orientation"] is None: 1720 args["orientation"] = "v" 1721 1722 if constructor == go.Histogram: 1723 if has_x and has_y and args["histfunc"] is None: 1724 args["histfunc"] = trace_patch["histfunc"] = "sum" 1725 1726 orientation = args["orientation"] 1727 nbins = args["nbins"] 1728 trace_patch["nbinsx"] = nbins if orientation == "v" else None 1729 trace_patch["nbinsy"] = None if orientation == "v" else nbins 1730 trace_patch["bingroup"] = "x" if orientation == "v" else "y" 1731 trace_patch["orientation"] = args["orientation"] 1732 1733 if constructor in [go.Violin, go.Box]: 1734 mode = "boxmode" if constructor == go.Box else "violinmode" 1735 if layout_patch[mode] is None and args["color"] is not None: 1736 if args["y"] == args["color"] and args["orientation"] == "h": 1737 layout_patch[mode] = "overlay" 1738 elif args["x"] == args["color"] and args["orientation"] == "v": 1739 layout_patch[mode] = "overlay" 1740 if layout_patch[mode] is None: 1741 layout_patch[mode] = "group" 1742 1743 if ( 1744 constructor == go.Histogram2d 1745 and args["z"] is not None 1746 and args["histfunc"] is None 1747 ): 1748 args["histfunc"] = trace_patch["histfunc"] = "sum" 1749 1750 if constructor in [go.Histogram2d, go.Densitymapbox]: 1751 show_colorbar = True 1752 trace_patch["coloraxis"] = "coloraxis1" 1753 1754 if "opacity" in args: 1755 if args["opacity"] is None: 1756 if "barmode" in args and args["barmode"] == "overlay": 1757 trace_patch["marker"] = dict(opacity=0.5) 1758 elif constructor in [go.Densitymapbox, go.Pie, go.Funnel, go.Funnelarea]: 1759 trace_patch["opacity"] = args["opacity"] 1760 else: 1761 trace_patch["marker"] = dict(opacity=args["opacity"]) 1762 if "line_group" in args: 1763 trace_patch["mode"] = "lines" + ("+markers+text" if args["text"] else "") 1764 elif constructor != go.Splom and ( 1765 "symbol" in args or constructor == go.Scattermapbox 1766 ): 1767 trace_patch["mode"] = "markers" + ("+text" if args["text"] else "") 1768 1769 if "line_shape" in args: 1770 trace_patch["line"] = dict(shape=args["line_shape"]) 1771 1772 if "geojson" in args: 1773 trace_patch["featureidkey"] = args["featureidkey"] 1774 trace_patch["geojson"] = ( 1775 args["geojson"] 1776 if not hasattr(args["geojson"], "__geo_interface__") # for geopandas 1777 else args["geojson"].__geo_interface__ 1778 ) 1779 1780 # Compute marginal attribute 1781 if "marginal" in args: 1782 position = "marginal_x" if args["orientation"] == "v" else "marginal_y" 1783 other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y" 1784 args[position] = args["marginal"] 1785 args[other_position] = None 1786 1787 # If both marginals and faceting are specified, faceting wins 1788 if args.get("facet_col", None) is not None and args.get("marginal_y", None): 1789 args["marginal_y"] = None 1790 1791 if args.get("facet_row", None) is not None and args.get("marginal_x", None): 1792 args["marginal_x"] = None 1793 1794 # facet_col_wrap only works if no marginals or row faceting is used 1795 if ( 1796 args.get("marginal_x", None) is not None 1797 or args.get("marginal_y", None) is not None 1798 or args.get("facet_row", None) is not None 1799 ): 1800 args["facet_col_wrap"] = 0 1801 1802 # Compute applicable grouping attributes 1803 for k in group_attrables: 1804 if k in args: 1805 grouped_attrs.append(k) 1806 1807 # Create grouped mappings 1808 grouped_mappings = [make_mapping(args, a) for a in grouped_attrs] 1809 1810 # Create trace specs 1811 trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) 1812 return trace_specs, grouped_mappings, sizeref, show_colorbar 1813 1814 1815def get_orderings(args, grouper, grouped): 1816 """ 1817 `orders` is the user-supplied ordering (with the remaining data-frame-supplied 1818 ordering appended if the column is used for grouping). It includes anything the user 1819 gave, for any variable, including values not present in the dataset. It is used 1820 downstream to set e.g. `categoryarray` for cartesian axes 1821 1822 `group_names` is the set of groups, ordered by the order above 1823 1824 `group_values` is a subset of `orders` in both keys and values. It contains a key 1825 for every grouped mapping and its values are the sorted *data* values for these 1826 mappings. 1827 """ 1828 orders = {} if "category_orders" not in args else args["category_orders"].copy() 1829 group_names = [] 1830 group_values = {} 1831 for group_name in grouped.groups: 1832 if len(grouper) == 1: 1833 group_name = (group_name,) 1834 group_names.append(group_name) 1835 for col in grouper: 1836 if col != one_group: 1837 uniques = args["data_frame"][col].unique() 1838 if col not in orders: 1839 orders[col] = list(uniques) 1840 else: 1841 for val in uniques: 1842 if val not in orders[col]: 1843 orders[col].append(val) 1844 group_values[col] = sorted(uniques, key=orders[col].index) 1845 1846 for i, col in reversed(list(enumerate(grouper))): 1847 if col != one_group: 1848 group_names = sorted( 1849 group_names, 1850 key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, 1851 ) 1852 1853 return orders, group_names, group_values 1854 1855 1856def make_figure(args, constructor, trace_patch=None, layout_patch=None): 1857 trace_patch = trace_patch or {} 1858 layout_patch = layout_patch or {} 1859 apply_default_cascade(args) 1860 1861 args = build_dataframe(args, constructor) 1862 if constructor in [go.Treemap, go.Sunburst] and args["path"] is not None: 1863 args = process_dataframe_hierarchy(args) 1864 if constructor == "timeline": 1865 constructor = go.Bar 1866 args = process_dataframe_timeline(args) 1867 1868 trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( 1869 args, constructor, trace_patch, layout_patch 1870 ) 1871 grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] 1872 grouped = args["data_frame"].groupby(grouper, sort=False) 1873 1874 orders, sorted_group_names, sorted_group_values = get_orderings( 1875 args, grouper, grouped 1876 ) 1877 1878 col_labels = [] 1879 row_labels = [] 1880 1881 for m in grouped_mappings: 1882 if m.grouper: 1883 if m.facet == "col": 1884 prefix = get_label(args, args["facet_col"]) + "=" 1885 col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]] 1886 if m.facet == "row": 1887 prefix = get_label(args, args["facet_row"]) + "=" 1888 row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]] 1889 for val in sorted_group_values[m.grouper]: 1890 if val not in m.val_map: 1891 m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] 1892 1893 subplot_type = _subplot_type_for_trace_type(constructor().type) 1894 1895 trace_names_by_frame = {} 1896 frames = OrderedDict() 1897 trendline_rows = [] 1898 nrows = ncols = 1 1899 trace_name_labels = None 1900 for group_name in sorted_group_names: 1901 group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) 1902 mapping_labels = OrderedDict() 1903 trace_name_labels = OrderedDict() 1904 frame_name = "" 1905 for col, val, m in zip(grouper, group_name, grouped_mappings): 1906 if col != one_group: 1907 key = get_label(args, col) 1908 if not isinstance(m.val_map, IdentityMap): 1909 mapping_labels[key] = str(val) 1910 if m.show_in_trace_name: 1911 trace_name_labels[key] = str(val) 1912 if m.variable == "animation_frame": 1913 frame_name = val 1914 trace_name = ", ".join(trace_name_labels.values()) 1915 if frame_name not in trace_names_by_frame: 1916 trace_names_by_frame[frame_name] = set() 1917 trace_names = trace_names_by_frame[frame_name] 1918 1919 for trace_spec in trace_specs: 1920 # Create the trace 1921 trace = trace_spec.constructor(name=trace_name) 1922 if trace_spec.constructor not in [ 1923 go.Parcats, 1924 go.Parcoords, 1925 go.Choropleth, 1926 go.Choroplethmapbox, 1927 go.Densitymapbox, 1928 go.Histogram2d, 1929 go.Sunburst, 1930 go.Treemap, 1931 ]: 1932 trace.update( 1933 legendgroup=trace_name, 1934 showlegend=(trace_name != "" and trace_name not in trace_names), 1935 ) 1936 if trace_spec.constructor in [go.Bar, go.Violin, go.Box, go.Histogram]: 1937 trace.update(alignmentgroup=True, offsetgroup=trace_name) 1938 trace_names.add(trace_name) 1939 1940 # Init subplot row/col 1941 trace._subplot_row = 1 1942 trace._subplot_col = 1 1943 1944 for i, m in enumerate(grouped_mappings): 1945 val = group_name[i] 1946 if val not in m.val_map: 1947 m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] 1948 try: 1949 m.updater(trace, m.val_map[val]) # covers most cases 1950 except ValueError: 1951 # this catches some odd cases like marginals 1952 if ( 1953 trace_spec != trace_specs[0] 1954 and trace_spec.constructor in [go.Violin, go.Box, go.Histogram] 1955 and m.variable == "symbol" 1956 ): 1957 pass 1958 elif ( 1959 trace_spec != trace_specs[0] 1960 and trace_spec.constructor in [go.Histogram] 1961 and m.variable == "color" 1962 ): 1963 trace.update(marker=dict(color=m.val_map[val])) 1964 elif ( 1965 trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox] 1966 and m.variable == "color" 1967 ): 1968 trace.update( 1969 z=[1] * len(group), 1970 colorscale=[m.val_map[val]] * 2, 1971 showscale=False, 1972 showlegend=True, 1973 ) 1974 else: 1975 raise 1976 1977 # Find row for trace, handling facet_row and marginal_x 1978 if m.facet == "row": 1979 row = m.val_map[val] 1980 else: 1981 if ( 1982 bool(args.get("marginal_x", False)) 1983 and trace_spec.marginal != "x" 1984 ): 1985 row = 2 1986 else: 1987 row = 1 1988 1989 facet_col_wrap = args.get("facet_col_wrap", 0) 1990 # Find col for trace, handling facet_col and marginal_y 1991 if m.facet == "col": 1992 col = m.val_map[val] 1993 if facet_col_wrap: # assumes no facet_row, no marginals 1994 row = 1 + ((col - 1) // facet_col_wrap) 1995 col = 1 + ((col - 1) % facet_col_wrap) 1996 else: 1997 if trace_spec.marginal == "y": 1998 col = 2 1999 else: 2000 col = 1 2001 2002 nrows = max(nrows, row) 2003 if row > 1: 2004 trace._subplot_row = row 2005 2006 ncols = max(ncols, col) 2007 if col > 1: 2008 trace._subplot_col = col 2009 if ( 2010 trace_specs[0].constructor == go.Histogram2dContour 2011 and trace_spec.constructor == go.Box 2012 and trace.line.color 2013 ): 2014 trace.update(marker=dict(color=trace.line.color)) 2015 2016 patch, fit_results = make_trace_kwargs( 2017 args, trace_spec, group, mapping_labels.copy(), sizeref 2018 ) 2019 trace.update(patch) 2020 if fit_results is not None: 2021 trendline_rows.append(mapping_labels.copy()) 2022 trendline_rows[-1]["px_fit_results"] = fit_results 2023 if frame_name not in frames: 2024 frames[frame_name] = dict(data=[], name=frame_name) 2025 frames[frame_name]["data"].append(trace) 2026 frame_list = [f for f in frames.values()] 2027 if len(frame_list) > 1: 2028 frame_list = sorted( 2029 frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"]) 2030 ) 2031 2032 if show_colorbar: 2033 colorvar = "z" if constructor in [go.Histogram2d, go.Densitymapbox] else "color" 2034 range_color = args["range_color"] or [None, None] 2035 2036 colorscale_validator = ColorscaleValidator("colorscale", "make_figure") 2037 layout_patch["coloraxis1"] = dict( 2038 colorscale=colorscale_validator.validate_coerce( 2039 args["color_continuous_scale"] 2040 ), 2041 cmid=args["color_continuous_midpoint"], 2042 cmin=range_color[0], 2043 cmax=range_color[1], 2044 colorbar=dict( 2045 title_text=get_decorated_label(args, args[colorvar], colorvar) 2046 ), 2047 ) 2048 for v in ["height", "width"]: 2049 if args[v]: 2050 layout_patch[v] = args[v] 2051 layout_patch["legend"] = dict(tracegroupgap=0) 2052 if trace_name_labels: 2053 layout_patch["legend"]["title_text"] = ", ".join(trace_name_labels) 2054 if args["title"]: 2055 layout_patch["title_text"] = args["title"] 2056 elif args["template"].layout.margin.t is None: 2057 layout_patch["margin"] = {"t": 60} 2058 if ( 2059 "size" in args 2060 and args["size"] 2061 and args["template"].layout.legend.itemsizing is None 2062 ): 2063 layout_patch["legend"]["itemsizing"] = "constant" 2064 2065 fig = init_figure( 2066 args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels 2067 ) 2068 2069 # Position traces in subplots 2070 for frame in frame_list: 2071 for trace in frame["data"]: 2072 if isinstance(trace, go.Splom): 2073 # Special case that is not compatible with make_subplots 2074 continue 2075 2076 _set_trace_grid_reference( 2077 trace, 2078 fig.layout, 2079 fig._grid_ref, 2080 nrows - trace._subplot_row + 1, 2081 trace._subplot_col, 2082 ) 2083 2084 # Add traces, layout and frames to figure 2085 fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else []) 2086 fig.update_layout(layout_patch) 2087 if "template" in args and args["template"] is not None: 2088 fig.update_layout(template=args["template"], overwrite=True) 2089 fig.frames = frame_list if len(frames) > 1 else [] 2090 2091 fig._px_trendlines = pd.DataFrame(trendline_rows) 2092 2093 configure_axes(args, constructor, fig, orders) 2094 configure_animation_controls(args, constructor, fig) 2095 return fig 2096 2097 2098def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels): 2099 # Build subplot specs 2100 specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)] 2101 2102 # Default row/column widths uniform 2103 column_widths = [1.0] * ncols 2104 row_heights = [1.0] * nrows 2105 facet_col_wrap = args.get("facet_col_wrap", 0) 2106 2107 # Build column_widths/row_heights 2108 if subplot_type == "xy": 2109 if bool(args.get("marginal_x", False)): 2110 if args["marginal_x"] == "histogram" or ("color" in args and args["color"]): 2111 main_size = 0.74 2112 else: 2113 main_size = 0.84 2114 2115 row_heights = [main_size] * (nrows - 1) + [1 - main_size] 2116 vertical_spacing = 0.01 2117 elif facet_col_wrap: 2118 vertical_spacing = args.get("facet_row_spacing", None) or 0.07 2119 else: 2120 vertical_spacing = args.get("facet_row_spacing", None) or 0.03 2121 2122 if bool(args.get("marginal_y", False)): 2123 if args["marginal_y"] == "histogram" or ("color" in args and args["color"]): 2124 main_size = 0.74 2125 else: 2126 main_size = 0.84 2127 2128 column_widths = [main_size] * (ncols - 1) + [1 - main_size] 2129 horizontal_spacing = 0.005 2130 else: 2131 horizontal_spacing = args.get("facet_col_spacing", None) or 0.02 2132 else: 2133 # Other subplot types: 2134 # 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None 2135 # 2136 # We can customize subplot spacing per type once we enable faceting 2137 # for all plot types 2138 if facet_col_wrap: 2139 vertical_spacing = args.get("facet_row_spacing", None) or 0.07 2140 else: 2141 vertical_spacing = args.get("facet_row_spacing", None) or 0.03 2142 horizontal_spacing = args.get("facet_col_spacing", None) or 0.02 2143 2144 if facet_col_wrap: 2145 subplot_labels = [None] * nrows * ncols 2146 while len(col_labels) < nrows * ncols: 2147 col_labels.append(None) 2148 for i in range(nrows): 2149 for j in range(ncols): 2150 subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j] 2151 2152 def _spacing_error_translator(e, direction, facet_arg): 2153 """ 2154 Translates the spacing errors thrown by the underlying make_subplots 2155 routine into one that describes an argument adjustable through px. 2156 """ 2157 if ("%s spacing" % (direction,)) in e.args[0]: 2158 e.args = ( 2159 e.args[0] 2160 + """ 2161Use the {facet_arg} argument to adjust this spacing.""".format( 2162 facet_arg=facet_arg 2163 ), 2164 ) 2165 raise e 2166 2167 # Create figure with subplots 2168 try: 2169 fig = make_subplots( 2170 rows=nrows, 2171 cols=ncols, 2172 specs=specs, 2173 shared_xaxes="all", 2174 shared_yaxes="all", 2175 row_titles=[] if facet_col_wrap else list(reversed(row_labels)), 2176 column_titles=[] if facet_col_wrap else col_labels, 2177 subplot_titles=subplot_labels if facet_col_wrap else [], 2178 horizontal_spacing=horizontal_spacing, 2179 vertical_spacing=vertical_spacing, 2180 row_heights=row_heights, 2181 column_widths=column_widths, 2182 start_cell="bottom-left", 2183 ) 2184 except ValueError as e: 2185 _spacing_error_translator(e, "Horizontal", "facet_col_spacing") 2186 _spacing_error_translator(e, "Vertical", "facet_row_spacing") 2187 2188 # Remove explicit font size of row/col titles so template can take over 2189 for annot in fig.layout.annotations: 2190 annot.update(font=None) 2191 2192 return fig 2193