from typing import NamedTuple

import numpy as np
from . import is_scalar_nan


def _unique(values, *, return_inverse=False):
    """Helper function to find unique values with support for python objects.

    Uses pure python method for object dtype, and numpy method for
    all other dtypes.

    Parameters
    ----------
    values : ndarray
        Values to find the unique elements of.

    return_inverse : bool, default=False
        If True, also return the indices of the unique values.

    Returns
    -------
    unique : ndarray
        The sorted unique values.

    unique_inverse : ndarray
        The indices to reconstruct the original array from the unique array.
        Only provided if `return_inverse` is True.
    """
    if values.dtype == object:
        return _unique_python(values, return_inverse=return_inverse)
    # numerical
    out = np.unique(values, return_inverse=return_inverse)

    if return_inverse:
        uniques, inverse = out
    else:
        uniques = out

    # np.unique may leave duplicate missing values at the end of `uniques`;
    # keep a single nan and point every nan entry of `inverse` at it.
    if uniques.size and is_scalar_nan(uniques[-1]):
        nan_idx = np.searchsorted(uniques, np.nan)
        uniques = uniques[: nan_idx + 1]
        if return_inverse:
            inverse[inverse > nan_idx] = nan_idx

    if return_inverse:
        return uniques, inverse
    return uniques


class MissingValues(NamedTuple):
    """Data class for missing data information"""

    # True when a nan value is present.
    nan: bool
    # True when None is present.
    none: bool

    def to_list(self):
        """Convert tuple to a list where None is always first."""
        output = []
        if self.none:
            output.append(None)
        if self.nan:
            output.append(np.nan)
        return output


def _extract_missing(values):
    """Extract missing values from `values`.

    Parameters
    ----------
    values: set
        Set of values to extract missing from.

    Returns
    -------
    output: set
        Set with missing values extracted.

    missing_values: MissingValues
        Object with missing value information.
    """
    missing_values_set = {
        value for value in values if value is None or is_scalar_nan(value)
    }

    if not missing_values_set:
        return values, MissingValues(nan=False, none=False)

    if None in missing_values_set:
        if len(missing_values_set) == 1:
            output_missing_values = MissingValues(nan=False, none=True)
        else:
            # If there is more than one missing value, then it has to be
            # float('nan') or np.nan
            output_missing_values = MissingValues(nan=True, none=True)
    else:
        output_missing_values = MissingValues(nan=True, none=False)

    # create set without the missing values
    output = values - missing_values_set
    return output, output_missing_values


class _nandict(dict):
    """Dictionary with support for nans.

    A lookup with any nan key returns the value stored under the first nan
    key found in `mapping`, even when the two nan objects are distinct.
    """

    def __init__(self, mapping):
        super().__init__(mapping)
        for key, value in mapping.items():
            if is_scalar_nan(key):
                self.nan_value = value
                break

    def __missing__(self, key):
        # Only reached when the regular dict lookup failed.
        if hasattr(self, "nan_value") and is_scalar_nan(key):
            return self.nan_value
        raise KeyError(key)


def _map_to_integer(values, uniques):
    """Map values based on its position in uniques.

    Raises KeyError if a value is not present in `uniques`.
    """
    table = _nandict({val: i for i, val in enumerate(uniques)})
    return np.array([table[v] for v in values])


def _unique_python(values, *, return_inverse):
    # Only used in `_unique`, see docstring there for details
    try:
        uniques_set = set(values)
        uniques_set, missing_values = _extract_missing(uniques_set)

        uniques = sorted(uniques_set)
        # Missing values are appended after the sorted values (None, then nan).
        uniques.extend(missing_values.to_list())
        uniques = np.array(uniques, dtype=values.dtype)
    except TypeError as exc:
        types = sorted(t.__qualname__ for t in {type(v) for v in values})
        raise TypeError(
            "Encoders require their input to be uniformly "
            f"strings or numbers. Got {types}"
        ) from exc

    if return_inverse:
        return uniques, _map_to_integer(values, uniques)

    return uniques


def _encode(values, *, uniques, check_unknown=True):
    """Helper function to encode values into [0, n_uniques - 1].

    Uses pure python method for object, str and bytes dtypes, and numpy
    method for all other dtypes.
    The numpy method has the limitation that the `uniques` need to
    be sorted. Importantly, this is not checked but assumed to already be
    the case. The calling method needs to ensure this for all non-object
    values.

    Parameters
    ----------
    values : ndarray
        Values to encode.
    uniques : ndarray
        The unique values in `values`. If the dtype is not object, then
        `uniques` needs to be sorted.
    check_unknown : bool, default=True
        If True, check for values in `values` that are not in `unique`
        and raise an error. This is ignored for object, str and bytes
        dtypes, and treated as True in this case. This parameter is useful
        for _BaseEncoder._transform() to avoid calling _check_unknown()
        twice.

    Returns
    -------
    encoded : ndarray
        Encoded values

    Raises
    ------
    ValueError
        If `values` contains a value that is not in `uniques` (for numeric
        dtypes only when `check_unknown` is True).
    """
    if values.dtype.kind in "OUS":
        try:
            return _map_to_integer(values, uniques)
        except KeyError as exc:
            raise ValueError(
                f"y contains previously unseen labels: {exc!s}"
            ) from exc
    else:
        if check_unknown:
            diff = _check_unknown(values, uniques)
            if diff:
                raise ValueError(f"y contains previously unseen labels: {diff!s}")
        return np.searchsorted(uniques, values)


def _check_unknown(values, known_values, return_mask=False):
    """
    Helper function to check for unknowns in values to be encoded.

    Uses pure python method for object, str and bytes dtypes, and numpy
    method for all other dtypes.

    Parameters
    ----------
    values : array
        Values to check for unknowns.
    known_values : array
        Known values. Must be unique.
    return_mask : bool, default=False
        If True, return a mask of the same shape as `values` indicating
        the valid values.

    Returns
    -------
    diff : list
        The unique values present in `values` and not in `known_values`.
    valid_mask : boolean array
        Additionally returned if ``return_mask=True``.

    """
    valid_mask = None

    if values.dtype.kind in "OUS":
        values_set = set(values)
        values_set, missing_in_values = _extract_missing(values_set)

        uniques_set = set(known_values)
        uniques_set, missing_in_uniques = _extract_missing(uniques_set)
        diff = values_set - uniques_set

        # Missing values are compared separately from the regular values
        # because nan != nan.
        nan_in_diff = missing_in_values.nan and not missing_in_uniques.nan
        none_in_diff = missing_in_values.none and not missing_in_uniques.none

        def is_valid(value):
            return (
                value in uniques_set
                or (missing_in_uniques.none and value is None)
                or (missing_in_uniques.nan and is_scalar_nan(value))
            )

        if return_mask:
            if diff or nan_in_diff or none_in_diff:
                valid_mask = np.array([is_valid(value) for value in values])
            else:
                valid_mask = np.ones(len(values), dtype=bool)

        diff = list(diff)
        if none_in_diff:
            diff.append(None)
        if nan_in_diff:
            diff.append(np.nan)
    else:
        unique_values = np.unique(values)
        diff = np.setdiff1d(unique_values, known_values, assume_unique=True)
        if return_mask:
            if diff.size:
                # np.isin replaces np.in1d, which is deprecated in NumPy 2.0.
                valid_mask = np.isin(values, known_values)
            else:
                valid_mask = np.ones(len(values), dtype=bool)

        # check for nans in the known_values
        if np.isnan(known_values).any():
            diff_is_nan = np.isnan(diff)
            if diff_is_nan.any():
                # nan is a known value, but the equality-based checks above
                # cannot match it: mark nan entries as valid in the mask
                if diff.size and return_mask:
                    is_nan = np.isnan(values)
                    valid_mask[is_nan] = True

                # remove nan from diff
                diff = diff[~diff_is_nan]
        diff = list(diff)

    if return_mask:
        return diff, valid_mask
    return diff