1"""
2Expressions
3-----------
4
5Offer fast expression evaluation through numexpr
6
7"""
8import operator
9from typing import List, Set
10import warnings
11
12import numpy as np
13
14from pandas._config import get_option
15
16from pandas.core.dtypes.generic import ABCDataFrame
17
18from pandas.core.computation.check import NUMEXPR_INSTALLED
19from pandas.core.ops import roperator
20
21if NUMEXPR_INSTALLED:
22    import numexpr as ne
23
# when truthy, the evaluation helpers report whether numexpr was used
_TEST_MODE = None
# outcomes collected while test mode is on (see set_test_mode/get_test_result)
_TEST_RESULT: List[bool] = []
# whether numexpr should be used; only changeable when numexpr is installed
USE_NUMEXPR = NUMEXPR_INSTALLED
# the active implementations; assigned by set_use_numexpr
_evaluate = None
_where = None

# the set of dtypes that we will allow to be passed to numexpr,
# keyed by the kind of check ("evaluate" or "where")
_ALLOWED_DTYPES = {
    "evaluate": {"int64", "int32", "float64", "float32", "bool"},
    "where": {"int64", "float64", "bool"},
}

# the product of the operand's shape must exceed this for numexpr to be used
# (below that we would only be adding overhead)
_MIN_ELEMENTS = 10000
38
39
def set_use_numexpr(v=True):
    """
    Switch numexpr-backed evaluation on or off.

    Parameters
    ----------
    v : bool, default True
        Requested state. Ignored when numexpr is not installed, in which
        case ``USE_NUMEXPR`` keeps its current value.
    """
    global USE_NUMEXPR, _evaluate, _where

    # the flag may only change when numexpr is actually importable
    if NUMEXPR_INSTALLED:
        USE_NUMEXPR = v

    # bind the implementations that evaluate() and where() dispatch to
    if USE_NUMEXPR:
        _evaluate = _evaluate_numexpr
        _where = _where_numexpr
    else:
        _evaluate = _evaluate_standard
        _where = _where_standard
51
52
def set_numexpr_threads(n=None):
    """
    Set the number of threads numexpr uses.

    Does nothing unless numexpr is installed and currently enabled.

    Parameters
    ----------
    n : int, optional
        Thread count to apply. When None, the count is reset to the value
        reported by ``ne.detect_number_of_cores()``.
    """
    if not (NUMEXPR_INSTALLED and USE_NUMEXPR):
        return

    thread_count = ne.detect_number_of_cores() if n is None else n
    ne.set_num_threads(thread_count)
60
61
def _evaluate_standard(op, op_str, a, b):
    """
    Evaluate ``op(a, b)`` directly, without numexpr.

    ``op_str`` is accepted so the signature matches ``_evaluate_numexpr``;
    it is not used here.
    """
    # in test mode, note that this call did not go through numexpr
    if _TEST_MODE:
        _store_test_result(False)

    # silence numpy floating-point warnings for the duration of the op
    with np.errstate(all="ignore"):
        outcome = op(a, b)
    return outcome
70
71
def _can_use_numexpr(op, op_str, a, b, dtype_check):
    """
    Return a boolean saying whether numexpr WILL be used for this call.

    Parameters
    ----------
    op : unused here; kept for a uniform call signature
    op_str : str or None
        Operator string; None means the op is never sent to numexpr.
    a, b : operands
    dtype_check : str
        Key into ``_ALLOWED_DTYPES`` ("evaluate" or "where").
    """
    if op_str is None:
        return False

    # required min elements (otherwise we are only adding overhead)
    if not (np.prod(a.shape) > _MIN_ELEMENTS):
        return False

    # collect the dtype names of both operands
    seen: Set[str] = set()
    for operand in (a, b):
        # Series implements dtypes too, so also require more than one dim
        if hasattr(operand, "dtypes") and operand.ndim > 1:
            counts = operand.dtypes.value_counts()
            if len(counts) > 1:
                # mixed-dtype frames are never handed to numexpr
                return False
            seen.update(counts.index.astype(str))
        # ndarray and Series case
        elif hasattr(operand, "dtype"):
            seen.add(operand.dtype.name)

    # usable when nothing was collected or the allowed set is a superset
    return not len(seen) or _ALLOWED_DTYPES[dtype_check] >= seen
96
97
def _evaluate_numexpr(op, op_str, a, b):
    """
    Evaluate ``op`` on ``a`` and ``b`` with numexpr when possible.

    Falls back to ``_evaluate_standard`` when ``_can_use_numexpr`` declines.
    In test mode, records whether numexpr produced the result.
    """
    result = None

    if _can_use_numexpr(op, op_str, a, b, "evaluate"):
        # a reversed op method (radd, rsub, ...) called us, so the operands
        # have to be flipped back before building the expression
        if op.__name__.strip("_").startswith("r"):
            a, b = b, a

        operands = {"a_value": a, "b_value": b}
        result = ne.evaluate(
            f"a_value {op_str} b_value",
            local_dict=operands,
            casting="safe",
        )

    if _TEST_MODE:
        _store_test_result(result is not None)

    # no numexpr result: evaluate with the plain Python operator instead
    return _evaluate_standard(op, op_str, a, b) if result is None else result
123
124
# Maps each supported operator callable to the operator string used to build
# the numexpr expression. A value of ``None`` means the op is never sent to
# numexpr and is always evaluated by calling the Python operator directly.
_op_str_mapping = {
    operator.add: "+",
    roperator.radd: "+",
    operator.mul: "*",
    roperator.rmul: "*",
    operator.sub: "-",
    roperator.rsub: "-",
    operator.truediv: "/",
    roperator.rtruediv: "/",
    operator.floordiv: "//",
    roperator.rfloordiv: "//",
    # we require Python semantics for mod of negative for backwards compatibility
    # see https://github.com/pydata/numexpr/issues/365
    # so sticking with unaccelerated for now; the reversed op evaluates the
    # same "%" expression with swapped operands, so it is unaccelerated too
    operator.mod: None,
    roperator.rmod: None,
    operator.pow: "**",
    roperator.rpow: "**",
    operator.eq: "==",
    operator.ne: "!=",
    operator.le: "<=",
    operator.lt: "<",
    operator.ge: ">=",
    operator.gt: ">",
    operator.and_: "&",
    roperator.rand_: "&",
    operator.or_: "|",
    roperator.ror_: "|",
    operator.xor: "^",
    roperator.rxor: "^",
    divmod: None,
    roperator.rdivmod: None,
}
158
159
def _where_standard(cond, a, b):
    """
    Select from ``a`` where ``cond`` holds and from ``b`` elsewhere.

    The caller is responsible for extracting ndarrays if necessary.
    """
    selected = np.where(cond, a, b)
    return selected
163
164
def _where_numexpr(cond, a, b):
    """
    Numexpr-backed ``where``; uses ``_where_standard`` when numexpr is
    not applicable to the operands.

    The caller is responsible for extracting ndarrays if necessary.
    """
    result = None

    if _can_use_numexpr(None, "where", a, b, "where"):
        operands = {"cond_value": cond, "a_value": a, "b_value": b}
        result = ne.evaluate(
            "where(cond_value, a_value, b_value)",
            local_dict=operands,
            casting="safe",
        )

    return _where_standard(cond, a, b) if result is None else result
181
182
# turn myself on: at import time, bind the numexpr or standard
# implementations according to the "compute.use_numexpr" option
set_use_numexpr(get_option("compute.use_numexpr"))
185
186
def _has_bool_dtype(x):
    """
    Return whether ``x`` is, or contains, boolean data.

    For a DataFrame the result is True when at least one column has bool
    dtype. Any other object exposing ``dtype`` is compared against ``bool``.
    Objects without a ``dtype`` attribute count only when they are
    ``bool`` or ``np.bool_`` scalars.
    """
    if isinstance(x, ABCDataFrame):
        # ``"bool" in x.dtypes`` would test the index of the dtypes Series
        # (i.e. the column labels), not the dtypes, so look at the values
        return any(dtype == bool for dtype in x.dtypes)
    try:
        return x.dtype == bool
    except AttributeError:
        return isinstance(x, (bool, np.bool_))
194
195
def _bool_arith_check(
    op_str, a, b, not_allowed=frozenset(("/", "//", "**")), unsupported=None
):
    """
    Decide whether numexpr may evaluate ``op_str`` on ``a`` and ``b``.

    The check only applies when both operands are boolean.

    Parameters
    ----------
    op_str : str
        Operator string.
    a, b : operands
    not_allowed : collection of str
        Operators that raise for bool operands.
    unsupported : dict, optional
        Maps operators numexpr cannot apply to bools to the operator that
        is suggested instead. Defaults to ``+ -> |``, ``* -> &``, ``- -> ^``.

    Returns
    -------
    bool
        False when the op must be evaluated in Python space, True otherwise.

    Raises
    ------
    NotImplementedError
        If both operands are boolean and ``op_str`` is in ``not_allowed``.
    """
    if unsupported is None:
        unsupported = {"+": "|", "*": "&", "-": "^"}

    # nothing to check unless both sides are boolean
    if not (_has_bool_dtype(a) and _has_bool_dtype(b)):
        return True

    if op_str in unsupported:
        warnings.warn(
            f"evaluating in Python space because the {repr(op_str)} "
            "operator is not supported by numexpr for "
            f"the bool dtype, use {repr(unsupported[op_str])} instead"
        )
        return False

    if op_str in not_allowed:
        raise NotImplementedError(
            f"operator {repr(op_str)} not implemented for bool dtypes"
        )

    return True
216
217
def evaluate(op, a, b, use_numexpr: bool = True):
    """
    Evaluate and return the result of applying op to a and b.

    Parameters
    ----------
    op : the actual operator; must be a key of ``_op_str_mapping``
    a : left operand
    b : right operand
    use_numexpr : bool, default True
        Whether to try to use numexpr.
    """
    op_str = _op_str_mapping[op]

    # ops without an operator string are never accelerated
    if op_str is None:
        return _evaluate_standard(op, op_str, a, b)

    # the bool check only runs when numexpr was requested
    if use_numexpr and _bool_arith_check(op_str, a, b):
        # error: "None" not callable
        return _evaluate(op, op_str, a, b)  # type: ignore[misc]

    return _evaluate_standard(op, op_str, a, b)
237
238
def where(cond, a, b, use_numexpr=True):
    """
    Evaluate the where condition cond on a and b.

    Parameters
    ----------
    cond : np.ndarray[bool]
    a : returned where cond is True
    b : returned where cond is False
    use_numexpr : bool, default True
        Whether to try to use numexpr.
    """
    # _where is bound by set_use_numexpr
    assert _where is not None

    if use_numexpr:
        return _where(cond, a, b)
    return _where_standard(cond, a, b)
253
254
def set_test_mode(v: bool = True) -> None:
    """
    Enable or disable tracking of whether numexpr was used.

    While enabled, an additional ``True`` is stored for every successful
    use of evaluate with numexpr since the last ``get_test_result``.
    Calling this always clears previously stored results.
    """
    global _TEST_MODE, _TEST_RESULT
    # start from a fresh list, then apply the requested mode
    _TEST_RESULT = []
    _TEST_MODE = v
265
266
def _store_test_result(used_numexpr: bool) -> None:
    """
    Record a use of numexpr; falsy outcomes are not stored.
    """
    if not used_numexpr:
        return
    _TEST_RESULT.append(used_numexpr)
271
272
def get_test_result() -> List[bool]:
    """
    Return the stored test results and reset the store to an empty list.
    """
    global _TEST_RESULT
    # hand back the current list and rebind the global to a new one
    collected, _TEST_RESULT = _TEST_RESULT, []
    return collected
281