1from typing import NamedTuple
2
3import numpy as np
4from . import is_scalar_nan
5
6
def _unique(values, *, return_inverse=False):
    """Find the unique entries of an array, including arrays of python objects.

    Object-dtype input is handed to the pure python implementation
    (`_unique_python`); every other dtype goes through `np.unique`.

    Parameters
    ----------
    values : ndarray
        Values whose unique entries are wanted.

    return_inverse : bool, default=False
        If True, also return the indices of the unique values.

    Returns
    -------
    unique : ndarray
        The sorted unique values.

    unique_inverse : ndarray
        The indices to reconstruct the original array from the unique array.
        Only provided if `return_inverse` is True.
    """
    if values.dtype == object:
        return _unique_python(values, return_inverse=return_inverse)

    # Every non-object dtype is delegated to numpy.
    if return_inverse:
        uniques, inverse = np.unique(values, return_inverse=True)
    else:
        uniques, inverse = np.unique(values), None

    # np.unique may leave several missing values at the tail of `uniques`;
    # keep only the first one and point every later index back at it.
    ends_with_nan = uniques.size > 0 and is_scalar_nan(uniques[-1])
    if ends_with_nan:
        first_nan = np.searchsorted(uniques, np.nan)
        uniques = uniques[: first_nan + 1]
        if return_inverse:
            inverse[inverse > first_nan] = first_nan

    return (uniques, inverse) if return_inverse else uniques
51
52
class MissingValues(NamedTuple):
    """Record of which kinds of missing value are present."""

    # True when a NaN value is present.
    nan: bool
    # True when None is present.
    none: bool

    def to_list(self):
        """Return the present missing markers as a list, None before NaN."""
        none_part = [None] if self.none else []
        nan_part = [np.nan] if self.nan else []
        return none_part + nan_part
67
68
def _extract_missing(values):
    """Split the missing markers (None and NaN) out of a set of values.

    Parameters
    ----------
    values: set
        Set of values to extract missing from.

    Returns
    -------
    output: set
        The values without any missing markers. When nothing is missing,
        the input set itself is returned.

    missing_values: MissingValues
        Object with missing value information.
    """
    found_missing = set()
    for item in values:
        if item is None or is_scalar_nan(item):
            found_missing.add(item)

    if not found_missing:
        return values, MissingValues(nan=False, none=False)

    has_none = None in found_missing
    # Anything collected besides None must be a NaN (float('nan') or
    # np.nan), so NaN is present unless None is the sole entry.
    has_nan = not has_none or len(found_missing) > 1

    remaining = values - found_missing
    return remaining, MissingValues(nan=has_nan, none=has_none)
105
106
class _nandict(dict):
    """Dict subclass in which any NaN key resolves to the stored NaN entry."""

    def __init__(self, mapping):
        super().__init__(mapping)
        # Remember the value stored under the first NaN key, if there is one,
        # so lookups with a different NaN object can still be answered.
        not_found = object()
        nan_values = (val for key, val in mapping.items() if is_scalar_nan(key))
        first_nan_value = next(nan_values, not_found)
        if first_nan_value is not not_found:
            self.nan_value = first_nan_value

    def __missing__(self, key):
        # Reached only when the regular dict lookup fails.
        if not hasattr(self, "nan_value") or not is_scalar_nan(key):
            raise KeyError(key)
        return self.nan_value
121
122
def _map_to_integer(values, uniques):
    """Replace each value by the index at which it appears in `uniques`."""
    positions = {}
    for index, unique_value in enumerate(uniques):
        positions[unique_value] = index

    # _nandict lets a NaN in `values` find the NaN entry of `uniques`.
    lookup = _nandict(positions)
    return np.array([lookup[item] for item in values])
127
128
def _unique_python(values, *, return_inverse):
    """Pure python counterpart of `_unique`; see its docstring for details.

    The non-missing unique values are sorted, and the missing markers
    (None first, then NaN) are appended after them. A TypeError raised while
    building that result (for instance when the values cannot be ordered
    against each other) is replaced by a TypeError that lists the types
    found in `values`.
    """
    try:
        present, missing = _extract_missing(set(values))
        ordered = sorted(present) + missing.to_list()
        uniques = np.array(ordered, dtype=values.dtype)
    except TypeError:
        types = sorted(t.__qualname__ for t in {type(v) for v in values})
        message = (
            "Encoders require their input to be uniformly "
            f"strings or numbers. Got {types}"
        )
        raise TypeError(message)

    if not return_inverse:
        return uniques

    return uniques, _map_to_integer(values, uniques)
149
150
def _encode(values, *, uniques, check_unknown=True):
    """Encode values as integers in [0, n_uniques - 1].

    Object and string dtypes (``dtype.kind`` in "OUS") are encoded with a
    pure python lookup; all other dtypes are encoded with
    `np.searchsorted`. The numpy path requires `uniques` to be sorted.
    This is assumed, not checked, so the caller has to guarantee it.

    Parameters
    ----------
    values : ndarray
        Values to encode.
    uniques : ndarray
        The unique values in `values`. Must be sorted whenever the numpy
        path is taken.
    check_unknown : bool, default=True
        If True, check for values in `values` that are not in `uniques`
        and raise an error. Not consulted on the pure python path, where an
        unknown value always raises. This parameter is useful for
        _BaseEncoder._transform() to avoid calling _check_unknown()
        twice.

    Returns
    -------
    encoded : ndarray
        Encoded values

    Raises
    ------
    ValueError
        If a value is not found in `uniques` (on the numpy path, only when
        `check_unknown` is True).
    """
    if values.dtype.kind not in "OUS":
        # numpy path: optionally validate, then binary-search the positions
        if check_unknown:
            diff = _check_unknown(values, uniques)
            if diff:
                raise ValueError(f"y contains previously unseen labels: {str(diff)}")
        return np.searchsorted(uniques, values)

    # pure python path: a failed lookup signals an unseen label
    try:
        return _map_to_integer(values, uniques)
    except KeyError as e:
        raise ValueError(f"y contains previously unseen labels: {str(e)}")
191
192
def _check_unknown(values, known_values, return_mask=False):
    """
    Helper function to check for unknowns in values to be encoded.

    Uses pure python method for object and string dtypes (``dtype.kind``
    in "OUS"), and numpy method for all other dtypes.

    Parameters
    ----------
    values : array
        Values to check for unknowns.
    known_values : array
        Known values. Must be unique.
    return_mask : bool, default=False
        If True, return a mask of the same shape as `values` indicating
        the valid values.

    Returns
    -------
    diff : list
        The unique values present in `values` and not in `known_values`.
    valid_mask : boolean array
        Additionally returned if ``return_mask=True``.

    """
    valid_mask = None

    if values.dtype.kind in "OUS":
        # Missing markers (None / NaN) are split off and compared separately
        # from the regular values.
        values_set = set(values)
        values_set, missing_in_values = _extract_missing(values_set)

        uniques_set = set(known_values)
        uniques_set, missing_in_uniques = _extract_missing(uniques_set)
        diff = values_set - uniques_set

        nan_in_diff = missing_in_values.nan and not missing_in_uniques.nan
        none_in_diff = missing_in_values.none and not missing_in_uniques.none

        def is_valid(value):
            # A value is valid if it is a known regular value, or a missing
            # marker of a kind that is also present in the known values.
            return (
                value in uniques_set
                or missing_in_uniques.none
                and value is None
                or missing_in_uniques.nan
                and is_scalar_nan(value)
            )

        if return_mask:
            if diff or nan_in_diff or none_in_diff:
                valid_mask = np.array([is_valid(value) for value in values])
            else:
                # nothing unknown: skip the per-element check
                valid_mask = np.ones(len(values), dtype=bool)

        diff = list(diff)
        if none_in_diff:
            diff.append(None)
        if nan_in_diff:
            diff.append(np.nan)
    else:
        unique_values = np.unique(values)
        diff = np.setdiff1d(unique_values, known_values, assume_unique=True)
        if return_mask:
            if diff.size:
                # np.isin is the replacement for np.in1d, which is
                # deprecated as of NumPy 2.0.
                valid_mask = np.isin(values, known_values)
            else:
                valid_mask = np.ones(len(values), dtype=bool)

        # check for nans in the known_values
        if np.isnan(known_values).any():
            # NaN does not compare equal to itself, so the set difference
            # above can still list NaN even though it is a known value.
            diff_is_nan = np.isnan(diff)
            if diff_is_nan.any():
                # mark the NaN positions of `values` as valid in the mask
                if diff.size and return_mask:
                    is_nan = np.isnan(values)
                    valid_mask[is_nan] = 1

                # remove nan from diff
                diff = diff[~diff_is_nan]
        diff = list(diff)

    if return_mask:
        return diff, valid_mask
    return diff
276