1"""
2Core eval alignment algorithms.
3"""
4from __future__ import annotations
5
6from functools import partial, wraps
7from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Type, Union
8import warnings
9
10import numpy as np
11
12from pandas._typing import FrameOrSeries
13from pandas.errors import PerformanceWarning
14
15from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
16
17from pandas.core.base import PandasObject
18import pandas.core.common as com
19from pandas.core.computation.common import result_type_many
20
21if TYPE_CHECKING:
22    from pandas.core.indexes.api import Index
23
24
def _align_core_single_unary_op(
    term,
) -> Tuple[Union[partial, Type[FrameOrSeries]], Optional[Dict[str, Index]]]:
    """
    Work out the result constructor and axes for a lone operand.

    Parameters
    ----------
    term : object
        Operand exposing the wrapped object as ``term.value``.

    Returns
    -------
    typ : functools.partial or type
        For an ndarray value, ``np.asanyarray`` bound to the value's dtype;
        otherwise ``type(term.value)``.
    axes : dict or None
        Axis name -> axis mapping when the (non-ndarray) value has an
        ``axes`` attribute, otherwise None.
    """
    value = term.value

    # raw ndarrays carry no labelled axes; only their dtype is preserved
    if isinstance(value, np.ndarray):
        return partial(np.asanyarray, dtype=value.dtype), None

    klass = type(value)
    if not hasattr(value, "axes"):
        return klass, None

    return klass, _zip_axes_from_type(klass, value.axes)
40
41
def _zip_axes_from_type(
    typ: Type[FrameOrSeries], new_axes: Sequence[Index]
) -> Dict[str, Index]:
    """
    Pair each axis name listed in ``typ._AXIS_ORDERS`` with the entry of
    ``new_axes`` found at the same position.

    Entries of ``new_axes`` beyond the number of axis names are ignored.
    """
    axes_by_name = {}
    for position, axis_name in enumerate(typ._AXIS_ORDERS):
        axes_by_name[axis_name] = new_axes[position]
    return axes_by_name
46
47
def _any_pandas_objects(terms) -> bool:
    """
    Report whether at least one term wraps a PandasObject instance.

    Returns False for an empty sequence of terms.
    """
    for term in terms:
        if isinstance(term.value, PandasObject):
            return True
    return False
53
54
def _filter_special_cases(f):
    """
    Decorator that handles the trivial alignment cases before calling ``f``.

    The wrapped function is only invoked when more than one term is given
    and at least one of them holds a pandas object.
    """

    @wraps(f)
    def filtered(terms):
        # a lone operand is resolved without any alignment work
        if len(terms) == 1:
            return _align_core_single_unary_op(terms[0])

        # at least one pandas object present: defer to the real alignment
        if _any_pandas_objects(terms):
            return f(terms)

        # nothing to align; only the common result type is needed
        return result_type_many(*(term.value for term in terms)), None

    return filtered
71
72
@_filter_special_cases
def _align_core(terms):
    """
    Align every axis-bearing term onto one common set of axes.

    The axes of the term with the largest ``ndim`` are taken as the starting
    point and outer-joined with the axes of each term that has an ``axes``
    attribute.  Each of those terms is then reindexed onto the joined axes
    and finally updated in place with the ``.values`` of its reindexed value.

    Parameters
    ----------
    terms : sequence
        Terms exposing ``value``, ``name`` and ``update``.

    Returns
    -------
    typ : callable
        ``_constructor`` of the highest-dimensional term.
    axes : dict
        Axis name -> joined axis mapping.
    """
    from pandas import Series

    # positions of the terms that carry axes, mapped to their dimensionality
    axed_positions = [
        pos for pos, term in enumerate(terms) if hasattr(term.value, "axes")
    ]
    ndims = Series({pos: terms[pos].value.ndim for pos in axed_positions})

    # seed the common axes with those of the highest-dimensional term
    widest = terms[ndims.idxmax()].value
    typ = widest._constructor
    axes = widest.axes
    naxes = len(axes)
    last_axis = naxes - 1
    multi_axis = naxes > 1

    # pass 1: grow the common axes by outer-joining every term's axes
    for pos in axed_positions:
        value = terms[pos].value
        # when the result has several axes, a Series' index is joined
        # against the last of them
        series_on_last_axis = multi_axis and isinstance(value, ABCSeries)

        for axis, items in enumerate(value.axes):
            if series_on_last_axis:
                target, labels = last_axis, value.index
            else:
                target, labels = axis, items

            if not axes[target].is_(labels):
                axes[target] = axes[target].join(labels, how="outer")

    # pass 2: reindex each term onto the common axes, then strip to values
    for pos, ndim in ndims.items():
        for axis, items in zip(range(ndim), axes):
            # re-read on every axis: the term may have just been updated
            current = terms[pos].value

            if not hasattr(current, "reindex"):
                continue

            use_last_axis = isinstance(current, ABCSeries) and naxes > 1
            reindexer = axes[last_axis] if use_last_axis else items

            size_gap = abs(len(reindexer) - len(current.axes[axis]))
            reindexer_size = len(reindexer)

            # warn once the lengths differ by 10 or more labels (log10 >= 1)
            # and the target axis holds at least 10000 labels
            ordm = np.log10(max(1, size_gap))
            if ordm >= 1 and reindexer_size >= 10000:
                w = (
                    f"Alignment difference on axis {axis} is larger "
                    f"than an order of magnitude on term {repr(terms[pos].name)}, "
                    f"by more than {ordm:.4g}; performance may suffer"
                )
                warnings.warn(w, category=PerformanceWarning, stacklevel=6)

            terms[pos].update(current.reindex(reindexer, axis=axis, copy=False))

        terms[pos].update(terms[pos].value.values)

    return typ, _zip_axes_from_type(typ, axes)
129
130
def align_terms(terms):
    """
    Align a set of terms.

    Parameters
    ----------
    terms : object
        Either a (possibly nested) iterable of terms or a single term.

    Returns
    -------
    typ : object
        Result type or constructor for the aligned expression.
    axes : dict or None
        Axis name -> axis mapping, or None when no labelled axes apply.
    """
    try:
        # the parse tree is a nested list; flatten it into one list of terms
        flat_terms = list(com.flatten(terms))
    except TypeError:
        # not iterable, so this is a single constant or variable
        if not isinstance(terms.value, (ABCSeries, ABCDataFrame)):
            return np.result_type(terms.type), None
        klass = type(terms.value)
        return klass, _zip_axes_from_type(klass, terms.value.axes)

    # every term is a scalar: only the common result type is needed
    if all(term.is_scalar for term in flat_terms):
        scalar_values = (term.value for term in flat_terms)
        return result_type_many(*scalar_values).type, None

    # otherwise run the full alignment
    typ, axes = _align_core(flat_terms)
    return typ, axes
152
153
def reconstruct_object(typ, obj, axes, dtype):
    """
    Rebuild a result of type ``typ`` from a raw value and optional axes.

    Parameters
    ----------
    typ : object
        A type or callable used to construct the result.  If it has a
        ``type`` attribute, that attribute is used instead.
    obj : object
        The raw value handed to the constructor; must expose ``dtype`` and
        ``shape``.
    axes : dict
        Keyword arguments naming the axes, used only when ``typ`` is a
        PandasObject subclass.
    dtype : object
        Dtype combined with ``obj.dtype`` via ``np.result_type`` to pick the
        dtype of the result.

    Returns
    -------
    ret : typ
        An object of type ``typ`` with the value `obj` and possible axes
        `axes`.
    """
    # unwrap objects that carry the real constructor in a ``type`` attribute
    typ = getattr(typ, "type", typ)

    result_dtype = np.result_type(obj.dtype, dtype)

    is_pandas_type = not isinstance(typ, partial) and issubclass(typ, PandasObject)
    if is_pandas_type:
        return typ(obj, dtype=result_dtype, **axes)

    # pathological inputs such as ~True/~False: a bool type whose result
    # dtype is no longer bool is built from the result dtype's scalar type
    bool_with_other_dtype = (
        hasattr(result_dtype, "type")
        and typ == np.bool_
        and result_dtype != np.bool_
    )
    if bool_with_other_dtype:
        return result_dtype.type(obj)

    rebuilt = typ(obj).astype(result_dtype)

    # Tell a one-element 1-dim array (np.array([0])) apart from a 0-dim
    # array (np.array(0), the scalar case): if the former came back as a
    # non-array, wrap it into a one-element array again.
    one_element_vector = len(obj.shape) == 1 and len(obj) == 1
    if one_element_vector and not isinstance(rebuilt, np.ndarray):
        rebuilt = np.array([rebuilt]).astype(result_dtype)

    return rebuilt
200