1"""
2Utility routines
3"""
4from collections.abc import Mapping
5from copy import deepcopy
6import json
7import itertools
8import re
9import sys
10import traceback
11import warnings
12
13import jsonschema
14import pandas as pd
15import numpy as np
16
17from .schemapi import SchemaBase, Undefined
18
19try:
20    from pandas.api.types import infer_dtype as _infer_dtype
21except ImportError:
22    # Import for pandas < 0.20.0
23    from pandas.lib import infer_dtype as _infer_dtype
24
25
def infer_dtype(value):
    """Infer the dtype of the value.

    This is a compatibility function for pandas infer_dtype,
    with skipna=False regardless of the pandas version.
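
    Examples
    --------
    >>> infer_dtype(pd.Series([1, 2, 3]))
    'integer'
    >>> infer_dtype(pd.Series(["a", np.nan]))  # skipna=False, so the NaN counts
    'mixed'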
31    """
32    if not hasattr(infer_dtype, "_supports_skipna"):
33        try:
34            _infer_dtype([1], skipna=False)
35        except TypeError:
36            # pandas < 0.21.0 don't support skipna keyword
37            infer_dtype._supports_skipna = False
38        else:
39            infer_dtype._supports_skipna = True
40    if infer_dtype._supports_skipna:
41        return _infer_dtype(value, skipna=False)
42    else:
43        return _infer_dtype(value)


TYPECODE_MAP = {
    "ordinal": "O",
    "nominal": "N",
    "quantitative": "Q",
    "temporal": "T",
    "geojson": "G",
}

INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()}


# aggregates from vega-lite version 4.6.0
AGGREGATES = [
    "argmax",
    "argmin",
    "average",
    "count",
    "distinct",
    "max",
    "mean",
    "median",
    "min",
    "missing",
    "product",
    "q1",
    "q3",
    "ci0",
    "ci1",
    "stderr",
    "stdev",
    "stdevp",
    "sum",
    "valid",
    "values",
    "variance",
    "variancep",
]

# window aggregates from vega-lite version 4.6.0
WINDOW_AGGREGATES = [
    "row_number",
    "rank",
    "dense_rank",
    "percent_rank",
    "cume_dist",
    "ntile",
    "lag",
    "lead",
    "first_value",
    "last_value",
    "nth_value",
]

# timeUnits from vega-lite version 4.6.0
TIMEUNITS = [
    "utcyear",
    "utcquarter",
    "utcmonth",
    "utcday",
    "utcdate",
    "utchours",
    "utcminutes",
    "utcseconds",
    "utcmilliseconds",
    "utcyearquarter",
    "utcyearquartermonth",
    "utcyearmonth",
    "utcyearmonthdate",
    "utcyearmonthdatehours",
    "utcyearmonthdatehoursminutes",
    "utcyearmonthdatehoursminutesseconds",
    "utcquartermonth",
    "utcmonthdate",
    "utcmonthdatehours",
    "utchoursminutes",
    "utchoursminutesseconds",
    "utcminutesseconds",
    "utcsecondsmilliseconds",
    "year",
    "quarter",
    "month",
    "day",
    "date",
    "hours",
    "minutes",
    "seconds",
    "milliseconds",
    "yearquarter",
    "yearquartermonth",
    "yearmonth",
    "yearmonthdate",
    "yearmonthdatehours",
    "yearmonthdatehoursminutes",
    "yearmonthdatehoursminutesseconds",
    "quartermonth",
    "monthdate",
    "monthdatehours",
    "hoursminutes",
    "hoursminutesseconds",
    "minutesseconds",
    "secondsmilliseconds",
]


def infer_vegalite_type(data):
    """
    From an array-like input, infer the correct vega typecode
    ('ordinal', 'nominal', 'quantitative', or 'temporal')

    Parameters
    ----------
    data: Numpy array or Pandas Series
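
    Examples
    --------
    >>> infer_vegalite_type(pd.Series(["A", "B", "C"]))
    'nominal'
    >>> infer_vegalite_type(pd.Series([1.0, 2.0, 3.0]))
    'quantitative'
    >>> infer_vegalite_type(pd.to_datetime(["2021-01-01", "2021-01-02"]))
    'temporal'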
158    """
159    # Otherwise, infer based on the dtype of the input
160    typ = infer_dtype(data)
161
162    # TODO: Once this returns 'O', please update test_select_x and test_select_y in test_api.py
163
164    if typ in [
165        "floating",
166        "mixed-integer-float",
167        "integer",
168        "mixed-integer",
169        "complex",
170    ]:
171        return "quantitative"
172    elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
173        return "nominal"
174    elif typ in [
175        "datetime",
176        "datetime64",
177        "timedelta",
178        "timedelta64",
179        "date",
180        "time",
181        "period",
182    ]:
183        return "temporal"
184    else:
185        warnings.warn(
186            "I don't know how to infer vegalite type from '{}'.  "
187            "Defaulting to nominal.".format(typ)
188        )
189        return "nominal"


def merge_props_geom(feat):
    """
    Merge properties with geometry
    * Overwrites 'type' and 'geometry' entries if existing
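
    Examples
    --------
    >>> feat = {'type': 'Feature', 'geometry': {'type': 'Point'}, 'properties': {'n': 1}}
    >>> merge_props_geom(feat) == {'n': 1, 'type': 'Feature', 'geometry': {'type': 'Point'}}
    True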
196    """
197
198    geom = {k: feat[k] for k in ("type", "geometry")}
199    try:
200        feat["properties"].update(geom)
201        props_geom = feat["properties"]
202    except (AttributeError, KeyError):
203        # AttributeError when 'properties' equals None
204        # KeyError when 'properties' is non-existing
205        props_geom = geom
206
207    return props_geom


def sanitize_geo_interface(geo):
    """Sanitize a geo_interface to prepare it for serialization.

    * Make a copy
    * Convert type array or _Array to list
    * Convert tuples to lists (using json.loads/dumps)
    * Merge properties with geometry
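
    Examples
    --------
    >>> geo = {'type': 'Point', 'coordinates': (0.0, 0.0)}
    >>> sanitize_geo_interface(geo) == {'type': 'Feature', 'geometry': {'type': 'Point', 'coordinates': [0.0, 0.0]}}
    True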
217    """
218
219    geo = deepcopy(geo)
220
221    # convert type _Array or array to list
222    for key in geo.keys():
223        if str(type(geo[key]).__name__).startswith(("_Array", "array")):
224            geo[key] = geo[key].tolist()
225
226    # convert (nested) tuples to lists
227    geo = json.loads(json.dumps(geo))
228
229    # sanitize features
230    if geo["type"] == "FeatureCollection":
231        geo = geo["features"]
232        if len(geo) > 0:
233            for idx, feat in enumerate(geo):
234                geo[idx] = merge_props_geom(feat)
235    elif geo["type"] == "Feature":
236        geo = merge_props_geom(geo)
237    else:
238        geo = {"type": "Feature", "geometry": geo}
239
240    return geo


def sanitize_dataframe(df):  # noqa: C901
    """Sanitize a DataFrame to prepare it for serialization.

    * Make a copy
    * Convert RangeIndex columns to strings
    * Raise ValueError if column names are not strings
    * Raise ValueError if it has a hierarchical index
    * Convert categoricals to strings
    * Convert np.bool_ dtypes to Python bool objects
    * Convert np.int dtypes to Python int objects
    * Convert floats to objects and replace NaNs/infs with None
    * Convert DateTime dtypes into appropriate string representations
    * Convert nullable integers to objects and replace NaN with None
    * Convert nullable booleans to objects and replace NaN with None
    * Convert dedicated string columns to objects and replace NaN with None
    * Raise a ValueError for TimeDelta dtypes
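
    Examples
    --------
    >>> df = pd.DataFrame({'x': [1.0, np.nan], 's': pd.Categorical(['a', None])})
    >>> sanitized = sanitize_dataframe(df)
    >>> sanitized['x'].tolist()
    [1.0, None]
    >>> sanitized['s'].tolist()
    ['a', None]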
259    """
260    df = df.copy()
261
262    if isinstance(df.columns, pd.RangeIndex):
263        df.columns = df.columns.astype(str)
264
265    for col in df.columns:
266        if not isinstance(col, str):
267            raise ValueError(
268                "Dataframe contains invalid column name: {0!r}. "
269                "Column names must be strings".format(col)
270            )
271
272    if isinstance(df.index, pd.MultiIndex):
273        raise ValueError("Hierarchical indices not supported")
274    if isinstance(df.columns, pd.MultiIndex):
275        raise ValueError("Hierarchical indices not supported")
276
277    def to_list_if_array(val):
278        if isinstance(val, np.ndarray):
279            return val.tolist()
280        else:
281            return val
282
283    for col_name, dtype in df.dtypes.iteritems():
284        if str(dtype) == "category":
285            # XXXX: work around bug in to_json for categorical types
286            # https://github.com/pydata/pandas/issues/10778
287            col = df[col_name].astype(object)
288            df[col_name] = col.where(col.notnull(), None)
289        elif str(dtype) == "string":
290            # dedicated string datatype (since 1.0)
291            # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type
292            col = df[col_name].astype(object)
293            df[col_name] = col.where(col.notnull(), None)
294        elif str(dtype) == "bool":
295            # convert numpy bools to objects; np.bool is not JSON serializable
296            df[col_name] = df[col_name].astype(object)
297        elif str(dtype) == "boolean":
298            # dedicated boolean datatype (since 1.0)
299            # https://pandas.io/docs/user_guide/boolean.html
300            col = df[col_name].astype(object)
301            df[col_name] = col.where(col.notnull(), None)
302        elif str(dtype).startswith("datetime"):
303            # Convert datetimes to strings. This needs to be a full ISO string
304            # with time, which is why we cannot use ``col.astype(str)``.
305            # This is because Javascript parses date-only times in UTC, but
306            # parses full ISO-8601 dates as local time, and dates in Vega and
307            # Vega-Lite are displayed in local time by default.
308            # (see https://github.com/altair-viz/altair/issues/1027)
309            df[col_name] = (
310                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
311            )
312        elif str(dtype).startswith("timedelta"):
313            raise ValueError(
314                'Field "{col_name}" has type "{dtype}" which is '
315                "not supported by Altair. Please convert to "
316                "either a timestamp or a numerical value."
317                "".format(col_name=col_name, dtype=dtype)
318            )
319        elif str(dtype).startswith("geometry"):
320            # geopandas >=0.6.1 uses the dtype geometry. Continue here
321            # otherwise it will give an error on np.issubdtype(dtype, np.integer)
322            continue
323        elif str(dtype) in {
324            "Int8",
325            "Int16",
326            "Int32",
327            "Int64",
328            "UInt8",
329            "UInt16",
330            "UInt32",
331            "UInt64",
332        }:  # nullable integer datatypes (since 24.0)
333            # https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support
334            col = df[col_name].astype(object)
335            df[col_name] = col.where(col.notnull(), None)
336        elif np.issubdtype(dtype, np.integer):
337            # convert integers to objects; np.int is not JSON serializable
338            df[col_name] = df[col_name].astype(object)
339        elif np.issubdtype(dtype, np.floating):
340            # For floats, convert to Python float: np.float is not JSON serializable
341            # Also convert NaN/inf values to null, as they are not JSON serializable
342            col = df[col_name]
343            bad_values = col.isnull() | np.isinf(col)
344            df[col_name] = col.astype(object).where(~bad_values, None)
345        elif dtype == object:
346            # Convert numpy arrays saved as objects to lists
347            # Arrays are not JSON serializable
348            col = df[col_name].apply(to_list_if_array, convert_dtype=False)
349            df[col_name] = col.where(col.notnull(), None)
350    return df


def parse_shorthand(
    shorthand,
    data=None,
    parse_aggregates=True,
    parse_window_ops=False,
    parse_timeunits=True,
    parse_types=True,
):
    """General tool to parse shorthand values

    These are of the form:

    - "col_name"
    - "col_name:O"
    - "average(col_name)"
    - "average(col_name):O"

    Optionally, a dataframe may be supplied, from which the type
    will be inferred if not specified in the shorthand.

    Parameters
    ----------
    shorthand : dict or string
        The shorthand representation to be parsed
    data : DataFrame, optional
        If specified and of type DataFrame, then use these values to infer the
        column type if not provided by the shorthand.
    parse_aggregates : boolean
        If True (default), then parse aggregate functions within the shorthand.
    parse_window_ops : boolean
        If True, then parse window operations within the shorthand (default: False).
    parse_timeunits : boolean
        If True (default), then parse timeUnits from within the shorthand.
    parse_types : boolean
        If True (default), then parse typecodes within the shorthand.

    Returns
    -------
    attrs : dict
        a dictionary of attributes extracted from the shorthand

    Examples
    --------
    >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'],
    ...                      'bar': [1, 2, 3, 4]})

    >>> parse_shorthand('name') == {'field': 'name'}
    True

    >>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'}
    True

    >>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'}
    True

    >>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'}
    True

    >>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'}
    True

    >>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'}
    True

    >>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'}
    True

    >>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'}
    True

    >>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'}
    True

    >>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'}
    True
    """
    if not shorthand:
        return {}

    valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP)

    units = dict(
        field="(?P<field>.*)",
        type="(?P<type>{})".format("|".join(valid_typecodes)),
        agg_count="(?P<aggregate>count)",
        op_count="(?P<op>count)",
        aggregate="(?P<aggregate>{})".format("|".join(AGGREGATES)),
        window_op="(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
        timeUnit="(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
    )

    patterns = []

    if parse_aggregates:
        patterns.extend([r"{agg_count}\(\)"])
        patterns.extend([r"{aggregate}\({field}\)"])
    if parse_window_ops:
        patterns.extend([r"{op_count}\(\)"])
        patterns.extend([r"{window_op}\({field}\)"])
    if parse_timeunits:
        patterns.extend([r"{timeUnit}\({field}\)"])

    patterns.extend([r"{field}"])

    if parse_types:
        patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns)))

    regexps = (
        re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns
    )

    # find matches depending on valid fields passed
    if isinstance(shorthand, dict):
        attrs = shorthand
    else:
        attrs = next(
            exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand)
        )

    # Handle short form of the type expression
    if "type" in attrs:
        attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"])

    # counts are quantitative by default
    if attrs == {"aggregate": "count"}:
        attrs["type"] = "quantitative"

    # times are temporal by default
    if "timeUnit" in attrs and "type" not in attrs:
        attrs["type"] = "temporal"

    # if data is specified and type is not, infer type from data
    if isinstance(data, pd.DataFrame) and "type" not in attrs:
        if "field" in attrs and attrs["field"] in data.columns:
            attrs["type"] = infer_vegalite_type(data[attrs["field"]])
    return attrs


def use_signature(Obj):
    """Apply call signature and documentation of Obj to the decorated method"""

    def decorate(f):
        # call-signature of f is exposed via __wrapped__.
        # we want it to mimic Obj.__init__
        f.__wrapped__ = Obj.__init__
        f._uses_signature = Obj

        # Supplement the docstring of f with information from Obj
        if Obj.__doc__:
            doclines = Obj.__doc__.splitlines()
            if f.__doc__:
                doc = f.__doc__ + "\n".join(doclines[1:])
            else:
                doc = "\n".join(doclines)
            try:
                f.__doc__ = doc
            except AttributeError:
                # __doc__ is not modifiable for classes in Python < 3.3
                pass

        return f

    return decorate
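

# A minimal usage sketch of ``use_signature`` (the class and function below are
# hypothetical, for illustration only):
#
#     class Scale:
#         """Scale documentation."""
#         def __init__(self, domain=None, range=None):
#             ...
#
#     @use_signature(Scale)
#     def scale(*args, **kwargs):
#         return Scale(*args, **kwargs)
#
#     # ``scale`` now exposes Scale.__init__'s call signature via __wrapped__
#     # (e.g. for help() and tab completion) and inherits Scale's docstring.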


def update_subtraits(obj, attrs, **kwargs):
    """Recursively update sub-traits without overwriting other traits"""
    # TODO: infer keywords from args
    if not kwargs:
        return obj

    # obj can be a SchemaBase object or a dict
    if obj is Undefined:
        obj = dct = {}
    elif isinstance(obj, SchemaBase):
        dct = obj._kwds
    else:
        dct = obj

    if isinstance(attrs, str):
        attrs = (attrs,)

    if len(attrs) == 0:
        dct.update(kwargs)
    else:
        attr = attrs[0]
        trait = dct.get(attr, Undefined)
        if trait is Undefined:
            trait = dct[attr] = {}
        dct[attr] = update_subtraits(trait, attrs[1:], **kwargs)
    return obj
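

# For example (illustrative; a plain dict stands in for a SchemaBase object):
#
#     update_subtraits({}, ("axis", "title"), fontSize=12)
#     # -> {'axis': {'title': {'fontSize': 12}}}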


def update_nested(original, update, copy=False):
    """Update nested dictionaries

    Parameters
    ----------
    original : dict
        the original (nested) dictionary, which will be updated in-place
    update : dict
        the nested dictionary of updates
    copy : bool, default False
        if True, then copy the original dictionary rather than modifying it

    Returns
    -------
    original : dict
        a reference to the (modified) original dict

    Examples
    --------
    >>> original = {'x': {'b': 2, 'c': 4}}
    >>> update = {'x': {'b': 5, 'd': 6}, 'y': 40}
    >>> update_nested(original, update)  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    >>> original  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    """
    if copy:
        original = deepcopy(original)
    for key, val in update.items():
        if isinstance(val, Mapping):
            orig_val = original.get(key, {})
            if isinstance(orig_val, Mapping):
                original[key] = update_nested(orig_val, val)
            else:
                original[key] = val
        else:
            original[key] = val
    return original


def display_traceback(in_ipython=True):
    exc_info = sys.exc_info()

    if in_ipython:
        from IPython.core.getipython import get_ipython

        ip = get_ipython()
    else:
        ip = None

    if ip is not None:
        ip.showtraceback(exc_info)
    else:
        traceback.print_exception(*exc_info)
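

# For example (illustrative), called from within an exception handler:
#
#     try:
#         1 / 0
#     except ZeroDivisionError:
#         display_traceback(in_ipython=False)  # prints the traceback to stderr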


def infer_encoding_types(args, kwargs, channels):
    """Infer typed keyword arguments for args and kwargs

    Parameters
    ----------
    args : tuple
        List of function args
    kwargs : dict
        Dict of function kwargs
    channels : module
        The module containing all altair encoding channel classes.

    Returns
    -------
    kwargs : dict
        All args and kwargs in a single dict, with keys and types
        based on the channels mapping.
    """
    # Construct a dictionary of channel type to encoding name
    # TODO: cache this somehow?
    channel_objs = (getattr(channels, name) for name in dir(channels))
    channel_objs = (
        c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
    )
    channel_to_name = {c: c._encoding_name for c in channel_objs}
    name_to_channel = {}
    for chan, name in channel_to_name.items():
        chans = name_to_channel.setdefault(name, {})
        key = "value" if chan.__name__.endswith("Value") else "field"
        chans[key] = chan

    # First use the mapping to convert args to kwargs based on their types.
    for arg in args:
        if isinstance(arg, (list, tuple)) and len(arg) > 0:
            type_ = type(arg[0])
        else:
            type_ = type(arg)

        encoding = channel_to_name.get(type_, None)
        if encoding is None:
            raise NotImplementedError(
                "positional argument of type {} not supported".format(type_)
            )
        if encoding in kwargs:
            raise ValueError("encoding {} specified twice.".format(encoding))
        kwargs[encoding] = arg

    def _wrap_in_channel_class(obj, encoding):
        try:
            condition = obj["condition"]
        except (KeyError, TypeError):
            pass
        else:
            if condition is not Undefined:
                obj = obj.copy()
                obj["condition"] = _wrap_in_channel_class(condition, encoding)

        if isinstance(obj, SchemaBase):
            return obj

        if isinstance(obj, str):
            obj = {"shorthand": obj}

        if isinstance(obj, (list, tuple)):
            return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]

        if encoding not in name_to_channel:
            warnings.warn("Unrecognized encoding channel '{}'".format(encoding))
            return obj

        classes = name_to_channel[encoding]
        cls = classes["value"] if "value" in obj else classes["field"]

        try:
            # Don't force validation here; some objects won't be valid until
            # they're created in the context of a chart.
            return cls.from_dict(obj, validate=False)
        except jsonschema.ValidationError:
            # our attempts at finding the correct class have failed
            return obj

    return {
        encoding: _wrap_in_channel_class(obj, encoding)
        for encoding, obj in kwargs.items()
    }
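

# A minimal usage sketch of ``infer_encoding_types`` (illustrative; assumes the
# altair 4.x channel module path below, which may differ between versions):
#
#     import altair as alt
#     from altair.vegalite.v4.schema import channels
#
#     infer_encoding_types((alt.X("a:Q"),), {"y": "b:N"}, channels)
#     # -> {'x': X('a:Q'), 'y': Y({'shorthand': 'b:N'})}  (roughly; reprs vary)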