"""
Expressions
-----------

Offer fast expression evaluation through numexpr

"""
import operator
from typing import List, Set
import warnings

import numpy as np

from pandas._config import get_option

from pandas.core.dtypes.generic import ABCDataFrame

from pandas.core.computation.check import NUMEXPR_INSTALLED
from pandas.core.ops import roperator

if NUMEXPR_INSTALLED:
    import numexpr as ne

# When truthy, record (via _store_test_result) each time numexpr is used.
_TEST_MODE = None
_TEST_RESULT: List[bool] = []
USE_NUMEXPR = NUMEXPR_INSTALLED
# Dispatch targets, assigned by set_use_numexpr at import time.
_evaluate = None
_where = None

# the set of dtypes that we will allow to pass to numexpr
_ALLOWED_DTYPES = {
    "evaluate": {"int64", "int32", "float64", "float32", "bool"},
    "where": {"int64", "float64", "bool"},
}

# the minimum number of elements (product of the shape) for which we use
# numexpr; below this the call overhead outweighs the benefit
_MIN_ELEMENTS = 10000


def set_use_numexpr(v=True):
    """
    Enable or disable numexpr-accelerated evaluation.

    Parameters
    ----------
    v : bool, default True
        Whether to use numexpr. When numexpr is not installed the flag is
        left unchanged (i.e. numexpr stays disabled).
    """
    global USE_NUMEXPR, _evaluate, _where
    if NUMEXPR_INSTALLED:
        USE_NUMEXPR = v

    # choose what we are going to do
    _evaluate = _evaluate_numexpr if USE_NUMEXPR else _evaluate_standard
    _where = _where_numexpr if USE_NUMEXPR else _where_standard


def set_numexpr_threads(n=None):
    """
    Set the number of threads used by numexpr.

    Does nothing when numexpr is not installed or not in use.

    Parameters
    ----------
    n : int, optional
        Number of threads. When None, use numexpr's detected number of cores.
    """
    if NUMEXPR_INSTALLED and USE_NUMEXPR:
        if n is None:
            n = ne.detect_number_of_cores()
        ne.set_num_threads(n)


def _evaluate_standard(op, op_str, a, b):
    """
    Standard evaluation: apply ``op(a, b)`` directly with numpy errors ignored.
    """
    if _TEST_MODE:
        _store_test_result(False)
    with np.errstate(all="ignore"):
        return op(a, b)


def _can_use_numexpr(op, op_str, a, b, dtype_check):
    """
    Return whether numexpr WILL be used for this operation.

    True only when ``op_str`` is not None, ``a`` has more than
    ``_MIN_ELEMENTS`` elements, and every dtype found on ``a`` and ``b`` is
    in ``_ALLOWED_DTYPES[dtype_check]`` (operands without dtypes, such as
    scalars, contribute nothing).
    """
    if op_str is not None:

        # required min elements (otherwise we are adding overhead)
        if np.prod(a.shape) > _MIN_ELEMENTS:
            # check for dtype compatibility
            dtypes: Set[str] = set()
            for o in [a, b]:
                # Series implements dtypes, check for dimension count as well
                if hasattr(o, "dtypes") and o.ndim > 1:
                    s = o.dtypes.value_counts()
                    if len(s) > 1:
                        # mixed dtypes cannot go through numexpr in one call
                        return False
                    dtypes |= set(s.index.astype(str))
                # ndarray and Series Case
                elif hasattr(o, "dtype"):
                    dtypes |= {o.dtype.name}

            # allowed are a superset
            if not len(dtypes) or _ALLOWED_DTYPES[dtype_check] >= dtypes:
                return True

    return False


def _evaluate_numexpr(op, op_str, a, b):
    """
    Evaluate ``op(a, b)`` with numexpr when possible, else fall back to the
    standard evaluation.
    """
    result = None

    if _can_use_numexpr(op, op_str, a, b, "evaluate"):
        is_reversed = op.__name__.strip("_").startswith("r")
        if is_reversed:
            # we were originally called by a reversed op method
            a, b = b, a

        a_value = a
        b_value = b

        result = ne.evaluate(
            f"a_value {op_str} b_value",
            local_dict={"a_value": a_value, "b_value": b_value},
            casting="safe",
        )

    if _TEST_MODE:
        _store_test_result(result is not None)

    if result is None:
        result = _evaluate_standard(op, op_str, a, b)

    return result


# Maps each supported binary operator to its numexpr expression string.
# A value of None means "never use numexpr for this operator".
_op_str_mapping = {
    operator.add: "+",
    roperator.radd: "+",
    operator.mul: "*",
    roperator.rmul: "*",
    operator.sub: "-",
    roperator.rsub: "-",
    operator.truediv: "/",
    roperator.rtruediv: "/",
    operator.floordiv: "//",
    roperator.rfloordiv: "//",
    # we require Python semantics for mod of negative for backwards compatibility
    # see https://github.com/pydata/numexpr/issues/365
    # so sticking with unaccelerated for now; this applies to the reversed
    # op as well, otherwise rmod would silently get C semantics
    operator.mod: None,
    roperator.rmod: None,
    operator.pow: "**",
    roperator.rpow: "**",
    operator.eq: "==",
    operator.ne: "!=",
    operator.le: "<=",
    operator.lt: "<",
    operator.ge: ">=",
    operator.gt: ">",
    operator.and_: "&",
    roperator.rand_: "&",
    operator.or_: "|",
    roperator.ror_: "|",
    operator.xor: "^",
    roperator.rxor: "^",
    divmod: None,
    roperator.rdivmod: None,
}


def _where_standard(cond, a, b):
    """Elementwise select from ``a`` where ``cond`` is True, else ``b``."""
    # Caller is responsible for extracting ndarray if necessary
    return np.where(cond, a, b)


def _where_numexpr(cond, a, b):
    """Like ``_where_standard`` but uses numexpr when possible."""
    # Caller is responsible for extracting ndarray if necessary
    result = None

    if _can_use_numexpr(None, "where", a, b, "where"):

        result = ne.evaluate(
            "where(cond_value, a_value, b_value)",
            local_dict={"cond_value": cond, "a_value": a, "b_value": b},
            casting="safe",
        )

    if result is None:
        result = _where_standard(cond, a, b)

    return result


# turn myself on
set_use_numexpr(get_option("compute.use_numexpr"))


def _has_bool_dtype(x):
    """
    Return whether ``x`` holds boolean data.

    For a DataFrame: True if any column has the bool dtype. For objects with
    a ``dtype``: whether that dtype equals bool. Otherwise: whether ``x`` is
    a Python or numpy bool scalar.
    """
    if isinstance(x, ABCDataFrame):
        # Note: ``"bool" in x.dtypes`` would test the index of the dtypes
        # Series (the column labels), not the dtypes themselves.
        return any(dtype == bool for dtype in x.dtypes)
    try:
        return x.dtype == bool
    except AttributeError:
        return isinstance(x, (bool, np.bool_))


def _bool_arith_check(
    op_str, a, b, not_allowed=frozenset(("/", "//", "**")), unsupported=None
):
    """
    Check whether ``op_str`` may be evaluated with numexpr for ``a`` and ``b``.

    Only restricts anything when both operands hold boolean data:
    operators in ``unsupported`` emit a warning and return False (evaluate
    in Python space); operators in ``not_allowed`` raise.

    Raises
    ------
    NotImplementedError
        If both operands are boolean and ``op_str`` is in ``not_allowed``.
    """
    if unsupported is None:
        unsupported = {"+": "|", "*": "&", "-": "^"}

    if _has_bool_dtype(a) and _has_bool_dtype(b):
        if op_str in unsupported:
            warnings.warn(
                f"evaluating in Python space because the {repr(op_str)} "
                "operator is not supported by numexpr for "
                f"the bool dtype, use {repr(unsupported[op_str])} instead"
            )
            return False

        if op_str in not_allowed:
            raise NotImplementedError(
                f"operator {repr(op_str)} not implemented for bool dtypes"
            )
    return True


def evaluate(op, a, b, use_numexpr: bool = True):
    """
    Evaluate and return the expression of the op on a and b.

    Parameters
    ----------
    op : callable
        The binary operator to apply (e.g. ``operator.add`` or
        ``roperator.radd``); must be a key of ``_op_str_mapping``, otherwise
        a KeyError is raised.
    a : left operand
    b : right operand
    use_numexpr : bool, default True
        Whether to try to use numexpr.
    """
    op_str = _op_str_mapping[op]
    if op_str is not None:
        use_numexpr = use_numexpr and _bool_arith_check(op_str, a, b)
        if use_numexpr:
            # error: "None" not callable
            return _evaluate(op, op_str, a, b)  # type: ignore[misc]
    return _evaluate_standard(op, op_str, a, b)


def where(cond, a, b, use_numexpr=True):
    """
    Evaluate the where condition cond on a and b.

    Parameters
    ----------
    cond : np.ndarray[bool]
    a : return if cond is True
    b : return if cond is False
    use_numexpr : bool, default True
        Whether to try to use numexpr.
    """
    assert _where is not None
    return _where(cond, a, b) if use_numexpr else _where_standard(cond, a, b)


def set_test_mode(v: bool = True) -> None:
    """
    Keeps track of whether numexpr was used.

    Stores an additional ``True`` for every successful use of evaluate with
    numexpr since the last ``get_test_result``.
    """
    global _TEST_MODE, _TEST_RESULT
    _TEST_MODE = v
    _TEST_RESULT = []


def _store_test_result(used_numexpr: bool) -> None:
    """Record a use of numexpr; calls with False are ignored."""
    # The list is only mutated here, so no ``global`` declaration is needed.
    if used_numexpr:
        _TEST_RESULT.append(used_numexpr)


def get_test_result() -> List[bool]:
    """
    Get test result and reset test_results.
    """
    global _TEST_RESULT
    res = _TEST_RESULT
    _TEST_RESULT = []
    return res