1""" 2Utility routines 3""" 4from collections.abc import Mapping 5from copy import deepcopy 6import json 7import itertools 8import re 9import sys 10import traceback 11import warnings 12 13import jsonschema 14import pandas as pd 15import numpy as np 16 17from .schemapi import SchemaBase, Undefined 18 19try: 20 from pandas.api.types import infer_dtype as _infer_dtype 21except ImportError: 22 # Import for pandas < 0.20.0 23 from pandas.lib import infer_dtype as _infer_dtype 24 25 26def infer_dtype(value): 27 """Infer the dtype of the value. 28 29 This is a compatibility function for pandas infer_dtype, 30 with skipna=False regardless of the pandas version. 31 """ 32 if not hasattr(infer_dtype, "_supports_skipna"): 33 try: 34 _infer_dtype([1], skipna=False) 35 except TypeError: 36 # pandas < 0.21.0 don't support skipna keyword 37 infer_dtype._supports_skipna = False 38 else: 39 infer_dtype._supports_skipna = True 40 if infer_dtype._supports_skipna: 41 return _infer_dtype(value, skipna=False) 42 else: 43 return _infer_dtype(value) 44 45 46TYPECODE_MAP = { 47 "ordinal": "O", 48 "nominal": "N", 49 "quantitative": "Q", 50 "temporal": "T", 51 "geojson": "G", 52} 53 54INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()} 55 56 57# aggregates from vega-lite version 4.6.0 58AGGREGATES = [ 59 "argmax", 60 "argmin", 61 "average", 62 "count", 63 "distinct", 64 "max", 65 "mean", 66 "median", 67 "min", 68 "missing", 69 "product", 70 "q1", 71 "q3", 72 "ci0", 73 "ci1", 74 "stderr", 75 "stdev", 76 "stdevp", 77 "sum", 78 "valid", 79 "values", 80 "variance", 81 "variancep", 82] 83 84# window aggregates from vega-lite version 4.6.0 85WINDOW_AGGREGATES = [ 86 "row_number", 87 "rank", 88 "dense_rank", 89 "percent_rank", 90 "cume_dist", 91 "ntile", 92 "lag", 93 "lead", 94 "first_value", 95 "last_value", 96 "nth_value", 97] 98 99# timeUnits from vega-lite version 4.6.0 100TIMEUNITS = [ 101 "utcyear", 102 "utcquarter", 103 "utcmonth", 104 "utcday", 105 "utcdate", 106 "utchours", 107 "utcminutes", 108 "utcseconds", 109 "utcmilliseconds", 110 "utcyearquarter", 111 "utcyearquartermonth", 112 "utcyearmonth", 113 "utcyearmonthdate", 114 "utcyearmonthdatehours", 115 "utcyearmonthdatehoursminutes", 116 "utcyearmonthdatehoursminutesseconds", 117 "utcquartermonth", 118 "utcmonthdate", 119 "utcmonthdatehours", 120 "utchoursminutes", 121 "utchoursminutesseconds", 122 "utcminutesseconds", 123 "utcsecondsmilliseconds", 124 "year", 125 "quarter", 126 "month", 127 "day", 128 "date", 129 "hours", 130 "minutes", 131 "seconds", 132 "milliseconds", 133 "yearquarter", 134 "yearquartermonth", 135 "yearmonth", 136 "yearmonthdate", 137 "yearmonthdatehours", 138 "yearmonthdatehoursminutes", 139 "yearmonthdatehoursminutesseconds", 140 "quartermonth", 141 "monthdate", 142 "monthdatehours", 143 "hoursminutes", 144 "hoursminutesseconds", 145 "minutesseconds", 146 "secondsmilliseconds", 147] 148 149 150def infer_vegalite_type(data): 151 """ 152 From an array-like input, infer the correct vega typecode 153 ('ordinal', 'nominal', 'quantitative', or 'temporal') 154 155 Parameters 156 ---------- 157 data: Numpy array or Pandas Series 158 """ 159 # Otherwise, infer based on the dtype of the input 160 typ = infer_dtype(data) 161 162 # TODO: Once this returns 'O', please update test_select_x and test_select_y in test_api.py 163 164 if typ in [ 165 "floating", 166 "mixed-integer-float", 167 "integer", 168 "mixed-integer", 169 "complex", 170 ]: 171 return "quantitative" 172 elif typ in ["string", "bytes", "categorical", "boolean", "mixed", 
"unicode"]: 173 return "nominal" 174 elif typ in [ 175 "datetime", 176 "datetime64", 177 "timedelta", 178 "timedelta64", 179 "date", 180 "time", 181 "period", 182 ]: 183 return "temporal" 184 else: 185 warnings.warn( 186 "I don't know how to infer vegalite type from '{}'. " 187 "Defaulting to nominal.".format(typ) 188 ) 189 return "nominal" 190 191 192def merge_props_geom(feat): 193 """ 194 Merge properties with geometry 195 * Overwrites 'type' and 'geometry' entries if existing 196 """ 197 198 geom = {k: feat[k] for k in ("type", "geometry")} 199 try: 200 feat["properties"].update(geom) 201 props_geom = feat["properties"] 202 except (AttributeError, KeyError): 203 # AttributeError when 'properties' equals None 204 # KeyError when 'properties' is non-existing 205 props_geom = geom 206 207 return props_geom 208 209 210def sanitize_geo_interface(geo): 211 """Santize a geo_interface to prepare it for serialization. 212 213 * Make a copy 214 * Convert type array or _Array to list 215 * Convert tuples to lists (using json.loads/dumps) 216 * Merge properties with geometry 217 """ 218 219 geo = deepcopy(geo) 220 221 # convert type _Array or array to list 222 for key in geo.keys(): 223 if str(type(geo[key]).__name__).startswith(("_Array", "array")): 224 geo[key] = geo[key].tolist() 225 226 # convert (nested) tuples to lists 227 geo = json.loads(json.dumps(geo)) 228 229 # sanitize features 230 if geo["type"] == "FeatureCollection": 231 geo = geo["features"] 232 if len(geo) > 0: 233 for idx, feat in enumerate(geo): 234 geo[idx] = merge_props_geom(feat) 235 elif geo["type"] == "Feature": 236 geo = merge_props_geom(geo) 237 else: 238 geo = {"type": "Feature", "geometry": geo} 239 240 return geo 241 242 243def sanitize_dataframe(df): # noqa: C901 244 """Sanitize a DataFrame to prepare it for serialization. 245 246 * Make a copy 247 * Convert RangeIndex columns to strings 248 * Raise ValueError if column names are not strings 249 * Raise ValueError if it has a hierarchical index. 250 * Convert categoricals to strings. 251 * Convert np.bool_ dtypes to Python bool objects 252 * Convert np.int dtypes to Python int objects 253 * Convert floats to objects and replace NaNs/infs with None. 254 * Convert DateTime dtypes into appropriate string representations 255 * Convert Nullable integers to objects and replace NaN with None 256 * Convert Nullable boolean to objects and replace NaN with None 257 * convert dedicated string column to objects and replace NaN with None 258 * Raise a ValueError for TimeDelta dtypes 259 """ 260 df = df.copy() 261 262 if isinstance(df.columns, pd.RangeIndex): 263 df.columns = df.columns.astype(str) 264 265 for col in df.columns: 266 if not isinstance(col, str): 267 raise ValueError( 268 "Dataframe contains invalid column name: {0!r}. 
" 269 "Column names must be strings".format(col) 270 ) 271 272 if isinstance(df.index, pd.MultiIndex): 273 raise ValueError("Hierarchical indices not supported") 274 if isinstance(df.columns, pd.MultiIndex): 275 raise ValueError("Hierarchical indices not supported") 276 277 def to_list_if_array(val): 278 if isinstance(val, np.ndarray): 279 return val.tolist() 280 else: 281 return val 282 283 for col_name, dtype in df.dtypes.iteritems(): 284 if str(dtype) == "category": 285 # XXXX: work around bug in to_json for categorical types 286 # https://github.com/pydata/pandas/issues/10778 287 col = df[col_name].astype(object) 288 df[col_name] = col.where(col.notnull(), None) 289 elif str(dtype) == "string": 290 # dedicated string datatype (since 1.0) 291 # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type 292 col = df[col_name].astype(object) 293 df[col_name] = col.where(col.notnull(), None) 294 elif str(dtype) == "bool": 295 # convert numpy bools to objects; np.bool is not JSON serializable 296 df[col_name] = df[col_name].astype(object) 297 elif str(dtype) == "boolean": 298 # dedicated boolean datatype (since 1.0) 299 # https://pandas.io/docs/user_guide/boolean.html 300 col = df[col_name].astype(object) 301 df[col_name] = col.where(col.notnull(), None) 302 elif str(dtype).startswith("datetime"): 303 # Convert datetimes to strings. This needs to be a full ISO string 304 # with time, which is why we cannot use ``col.astype(str)``. 305 # This is because Javascript parses date-only times in UTC, but 306 # parses full ISO-8601 dates as local time, and dates in Vega and 307 # Vega-Lite are displayed in local time by default. 308 # (see https://github.com/altair-viz/altair/issues/1027) 309 df[col_name] = ( 310 df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "") 311 ) 312 elif str(dtype).startswith("timedelta"): 313 raise ValueError( 314 'Field "{col_name}" has type "{dtype}" which is ' 315 "not supported by Altair. Please convert to " 316 "either a timestamp or a numerical value." 317 "".format(col_name=col_name, dtype=dtype) 318 ) 319 elif str(dtype).startswith("geometry"): 320 # geopandas >=0.6.1 uses the dtype geometry. 

def parse_shorthand(
    shorthand,
    data=None,
    parse_aggregates=True,
    parse_window_ops=False,
    parse_timeunits=True,
    parse_types=True,
):
    """General tool to parse shorthand values

    These are of the form:

    - "col_name"
    - "col_name:O"
    - "average(col_name)"
    - "average(col_name):O"

    Optionally, a dataframe may be supplied, from which the type
    will be inferred if not specified in the shorthand.

    Parameters
    ----------
    shorthand : dict or string
        The shorthand representation to be parsed
    data : DataFrame, optional
        If specified and of type DataFrame, then use these values to infer the
        column type if not provided by the shorthand.
    parse_aggregates : boolean
        If True (default), then parse aggregate functions within the shorthand.
    parse_window_ops : boolean
        If True, then parse window operations within the shorthand (default: False)
    parse_timeunits : boolean
        If True (default), then parse timeUnits from within the shorthand
    parse_types : boolean
        If True (default), then parse typecodes within the shorthand

    Returns
    -------
    attrs : dict
        a dictionary of attributes extracted from the shorthand

    Examples
    --------
    >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'],
    ...                      'bar': [1, 2, 3, 4]})

    >>> parse_shorthand('name') == {'field': 'name'}
    True

    >>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'}
    True

    >>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'}
    True

    >>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'}
    True

    >>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'}
    True

    >>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'}
    True

    >>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'}
    True

    >>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'}
    True

    >>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'}
    True

    >>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'}
    True
    """
    if not shorthand:
        return {}

    valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP)

    units = dict(
        field="(?P<field>.*)",
        type="(?P<type>{})".format("|".join(valid_typecodes)),
        agg_count="(?P<aggregate>count)",
        op_count="(?P<op>count)",
        aggregate="(?P<aggregate>{})".format("|".join(AGGREGATES)),
        window_op="(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
        timeUnit="(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
    )

    patterns = []

    if parse_aggregates:
        patterns.extend([r"{agg_count}\(\)"])
        patterns.extend([r"{aggregate}\({field}\)"])
    if parse_window_ops:
        patterns.extend([r"{op_count}\(\)"])
        patterns.extend([r"{window_op}\({field}\)"])
    if parse_timeunits:
        patterns.extend([r"{timeUnit}\({field}\)"])

    patterns.extend([r"{field}"])

    if parse_types:
        patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns)))

    regexps = (
        re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns
    )

    # find matches depending on valid fields passed
    if isinstance(shorthand, dict):
        attrs = shorthand
    else:
        attrs = next(
            exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand)
        )

    # Handle short form of the type expression
    if "type" in attrs:
        attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"])

    # counts are quantitative by default
    if attrs == {"aggregate": "count"}:
        attrs["type"] = "quantitative"

    # times are temporal by default
    if "timeUnit" in attrs and "type" not in attrs:
        attrs["type"] = "temporal"

    # if data is specified and type is not, infer type from data
    if isinstance(data, pd.DataFrame) and "type" not in attrs:
        if "field" in attrs and attrs["field"] in data.columns:
            attrs["type"] = infer_vegalite_type(data[attrs["field"]])
    return attrs

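
# Window operations are only parsed when explicitly requested, and they
# populate 'op' rather than 'aggregate' (hedged REPL sketch; the field
# name is made up):
#
#     >>> parse_shorthand('rank(score)', parse_window_ops=True)
#     {'op': 'rank', 'field': 'score'}
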

def use_signature(Obj):
    """Apply call signature and documentation of Obj to the decorated method"""

    def decorate(f):
        # call-signature of f is exposed via __wrapped__.
        # we want it to mimic Obj.__init__
        f.__wrapped__ = Obj.__init__
        f._uses_signature = Obj

        # Supplement the docstring of f with information from Obj
        if Obj.__doc__:
            doclines = Obj.__doc__.splitlines()
            if f.__doc__:
                doc = f.__doc__ + "\n".join(doclines[1:])
            else:
                doc = "\n".join(doclines)
            try:
                f.__doc__ = doc
            except AttributeError:
                # __doc__ is not modifiable for classes in Python < 3.3
                pass

        return f

    return decorate

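
# Typical use of use_signature (hedged sketch; ``SomeSchema`` is a
# hypothetical SchemaBase subclass, not a class defined in this module):
#
#     @use_signature(SomeSchema)
#     def configure_something(self, *args, **kwargs):
#         copy = self.copy()
#         copy.config = SomeSchema(*args, **kwargs)
#         return copy
#
# Introspection tools (Sphinx, IPython's `?`) then display SomeSchema's
# call signature and docstring for the wrapped method.
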

def update_subtraits(obj, attrs, **kwargs):
    """Recursively update sub-traits without overwriting other traits"""
    # TODO: infer keywords from args
    if not kwargs:
        return obj

    # obj can be a SchemaBase object or a dict
    if obj is Undefined:
        obj = dct = {}
    elif isinstance(obj, SchemaBase):
        dct = obj._kwds
    else:
        dct = obj

    if isinstance(attrs, str):
        attrs = (attrs,)

    if len(attrs) == 0:
        dct.update(kwargs)
    else:
        attr = attrs[0]
        trait = dct.get(attr, Undefined)
        if trait is Undefined:
            trait = dct[attr] = {}
        dct[attr] = update_subtraits(trait, attrs[1:], **kwargs)
    return obj

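
# update_subtraits walks down the given attribute path, creating empty
# dicts as needed, and applies the keywords at the leaf (hedged REPL
# sketch; the 'axis'/'title' path is made up for illustration):
#
#     >>> update_subtraits({}, ('axis', 'title'), fontSize=12)
#     {'axis': {'title': {'fontSize': 12}}}
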

def update_nested(original, update, copy=False):
    """Update nested dictionaries

    Parameters
    ----------
    original : dict
        the original (nested) dictionary, which will be updated in-place
    update : dict
        the nested dictionary of updates
    copy : bool, default False
        if True, then copy the original dictionary rather than modifying it

    Returns
    -------
    original : dict
        a reference to the (modified) original dict

    Examples
    --------
    >>> original = {'x': {'b': 2, 'c': 4}}
    >>> update = {'x': {'b': 5, 'd': 6}, 'y': 40}
    >>> update_nested(original, update)  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    >>> original  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    """
    if copy:
        original = deepcopy(original)
    for key, val in update.items():
        if isinstance(val, Mapping):
            orig_val = original.get(key, {})
            if isinstance(orig_val, Mapping):
                original[key] = update_nested(orig_val, val)
            else:
                original[key] = val
        else:
            original[key] = val
    return original


def display_traceback(in_ipython=True):
    exc_info = sys.exc_info()

    if in_ipython:
        from IPython.core.getipython import get_ipython

        ip = get_ipython()
    else:
        ip = None

    if ip is not None:
        ip.showtraceback(exc_info)
    else:
        traceback.print_exception(*exc_info)


def infer_encoding_types(args, kwargs, channels):
    """Infer typed keyword arguments for args and kwargs

    Parameters
    ----------
    args : tuple
        Tuple of function args
    kwargs : dict
        Dict of function kwargs
    channels : module
        The module containing all altair encoding channel classes.

    Returns
    -------
    kwargs : dict
        All args and kwargs in a single dict, with keys and types
        based on the channels mapping.
    """
    # Construct a dictionary of channel type to encoding name
    # TODO: cache this somehow?
    channel_objs = (getattr(channels, name) for name in dir(channels))
    channel_objs = (
        c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
    )
    channel_to_name = {c: c._encoding_name for c in channel_objs}
    name_to_channel = {}
    for chan, name in channel_to_name.items():
        chans = name_to_channel.setdefault(name, {})
        key = "value" if chan.__name__.endswith("Value") else "field"
        chans[key] = chan

    # First use the mapping to convert args to kwargs based on their types.
    for arg in args:
        if isinstance(arg, (list, tuple)) and len(arg) > 0:
            type_ = type(arg[0])
        else:
            type_ = type(arg)

        encoding = channel_to_name.get(type_, None)
        if encoding is None:
            raise NotImplementedError("positional of type {}".format(type_))
        if encoding in kwargs:
            raise ValueError("encoding {} specified twice.".format(encoding))
        kwargs[encoding] = arg

    def _wrap_in_channel_class(obj, encoding):
        try:
            condition = obj["condition"]
        except (KeyError, TypeError):
            pass
        else:
            if condition is not Undefined:
                obj = obj.copy()
                obj["condition"] = _wrap_in_channel_class(condition, encoding)

        if isinstance(obj, SchemaBase):
            return obj

        if isinstance(obj, str):
            obj = {"shorthand": obj}

        if isinstance(obj, (list, tuple)):
            return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]

        if encoding not in name_to_channel:
            warnings.warn("Unrecognized encoding channel '{}'".format(encoding))
            return obj

        classes = name_to_channel[encoding]
        cls = classes["value"] if "value" in obj else classes["field"]

        try:
            # Don't force validation here; some objects won't be valid until
            # they're created in the context of a chart.
            return cls.from_dict(obj, validate=False)
        except jsonschema.ValidationError:
            # our attempts at finding the correct class have failed
            return obj

    return {
        encoding: _wrap_in_channel_class(obj, encoding)
        for encoding, obj in kwargs.items()
    }
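
# Sketch of how infer_encoding_types is typically driven (hedged; the
# import path below is an assumption about the surrounding package layout,
# and the exact reprs will differ):
#
#     >>> import altair as alt
#     >>> from altair.vegalite.v4.schema import channels
#     >>> infer_encoding_types((alt.X('a:Q'),), {'y': 'b:N'}, channels)
#     {'x': X('a:Q'), 'y': Y({'shorthand': 'b:N'})}
#
# Positional channel instances are keyed by their encoding name, and bare
# shorthand strings are wrapped in the matching field channel class.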