"""
This file contains a minimal set of tests for compliance with the extension
array interface test suite, and should contain no other tests.
The test suite for the full functionality of the array is located in
`pandas/tests/arrays/`.

The tests in this file are inherited from the BaseExtensionTests, and only
minimal tweaks should be applied to get the tests passing (by overwriting a
parent method).

Additional tests should either be added to one of the BaseExtensionTests
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.

"""
import numpy as np
import pytest

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays.boolean import BooleanDtype
from pandas.tests.extension import base


def make_data():
    """Return 100 raw values: alternating True/False with NaN at index 8 and 97."""
    pair = [True, False]
    missing = [np.nan]
    return pair * 4 + missing + pair * 44 + missing + pair


def _make_grouping_frame(values):
    """Build the two-column frame shared by several groupby tests.

    Column "A" holds the fixed keys ``[1, 1, 2, 2, 3, 3, 1]`` and column "B"
    holds ``values`` (the ``data_for_grouping`` array in every caller).
    """
    return pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": values})


@pytest.fixture
def dtype():
    """A fresh ``BooleanDtype`` instance."""
    return BooleanDtype()


@pytest.fixture
def data(dtype):
    """Array built from ``make_data()`` with the dtype under test."""
    raw_values = make_data()
    return pd.array(raw_values, dtype=dtype)


@pytest.fixture
def data_for_twos(dtype):
    """Array built from ``np.ones(100)`` with the dtype under test."""
    ones = np.ones(100)
    return pd.array(ones, dtype=dtype)


@pytest.fixture
def data_missing(dtype):
    """Two-element array: a missing value followed by ``True``."""
    return pd.array([np.nan, True], dtype=dtype)


@pytest.fixture
def data_for_sorting(dtype):
    """Three-element array ``[True, True, False]`` without missing values."""
    return pd.array([True, True, False], dtype=dtype)


@pytest.fixture
def data_missing_for_sorting(dtype):
    """Three-element array ``[True, NA, False]``."""
    return pd.array([True, np.nan, False], dtype=dtype)


@pytest.fixture
def na_cmp():
    """Comparator that is true only when both operands are ``pd.NA``."""

    def both_are_na(x, y):
        # identity check: the missing value for this dtype is the pd.NA singleton
        return x is pd.NA and y is pd.NA

    return both_are_na


@pytest.fixture
def na_value():
    """The missing-value marker, ``pd.NA``."""
    return pd.NA


@pytest.fixture
def data_for_grouping(dtype):
    """Array laid out as ``[B, B, NA, NA, A, A, B]`` with B=True and A=False."""
    upper, lower, missing = True, False, np.nan
    layout = [upper, upper, missing, missing, lower, lower, upper]
    return pd.array(layout, dtype=dtype)


class TestDtype(base.BaseDtypeTests):
    """Inherited dtype tests, run without modification."""


class TestInterface(base.BaseInterfaceTests):
    """Inherited interface tests, run without modification."""


class TestConstructors(base.BaseConstructorsTests):
    """Inherited constructor tests, run without modification."""


class TestGetitem(base.BaseGetitemTests):
    """Inherited getitem tests, run without modification."""


class TestSetitem(base.BaseSetitemTests):
    """Inherited setitem tests, run without modification."""


class TestMissing(base.BaseMissingTests):
    """Inherited missing-value tests, run without modification."""


class TestArithmeticOps(base.BaseArithmeticOpsTests):
    """Arithmetic-op tests with expectations adjusted for the boolean array."""

    # Op names for which ``_check_op`` expects a TypeError and which
    # ``test_arith_frame_with_scalar`` does not mark as xfail.
    implements = {"__sub__", "__rsub__"}

    def check_opname(self, s, op_name, other, exc=None):
        """Delegate to the parent, always with ``exc=None`` (ops don't raise)."""
        super().check_opname(s, op_name, other, exc=None)

    def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
        """Compare ``op(s, other)`` against an expectation built via ``combine``.

        When ``exc`` is not None the op is simply expected to raise ``exc``.
        """
        if exc is not None:
            with pytest.raises(exc):
                op(s, other)
            return

        if op_name in self.implements:
            pattern = r"numpy boolean subtract"
            with pytest.raises(TypeError, match=pattern):
                op(s, other)
            return

        int8_result_ops = (
            "__floordiv__",
            "__rfloordiv__",
            "__pow__",
            "__rpow__",
            "__mod__",
            "__rmod__",
        )
        float64_result_ops = ("__truediv__", "__rtruediv__")

        result = op(s, other)
        expected = s.combine(other, op)

        if op_name in int8_result_ops:
            # combine keeps boolean type
            expected = expected.astype("Int8")
        elif op_name in float64_result_ops:
            # combine with bools does not generate the correct result
            # (numpy behaviour for div is to regard the bools as numeric)
            as_float = s.astype(float)
            expected = as_float.combine(other, op).astype("Float64")

        if op_name == "__rpow__":
            # for rpow, combine does not propagate NaN
            expected[result.isna()] = np.nan

        self.assert_series_equal(result, expected)

    def _check_divmod_op(self, s, op, other, exc=None):
        """Delegate to the parent with no expected exception."""
        super()._check_divmod_op(s, op, other, None)

    @pytest.mark.skip(reason="BooleanArray does not error on ops")
    def test_error(self, data, all_arithmetic_operators):
        """Skipped; specific errors are tested in the boolean array tests."""

    def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
        """Run the parent frame & scalar test, xfailing ops not in ``implements``."""
        if all_arithmetic_operators not in self.implements:
            xfail_marker = pytest.mark.xfail(reason="_reduce needs implementation")
            request.node.add_marker(xfail_marker)
        super().test_arith_frame_with_scalar(data, all_arithmetic_operators)


class TestComparisonOps(base.BaseComparisonOpsTests):
    """Comparison-op tests; scalar/array comparisons are covered elsewhere."""

    def check_opname(self, s, op_name, other, exc=None):
        """Delegate to the parent, always with ``exc=None`` (ops don't raise)."""
        super().check_opname(s, op_name, other, exc=None)

    def _compare_other(self, s, data, op_name, other):
        """Route the comparison through ``check_opname``."""
        self.check_opname(s, op_name, other)

    @pytest.mark.skip(reason="Tested in tests/arrays/test_boolean.py")
    def test_compare_scalar(self, data, all_compare_operators):
        """Skipped; see the skip reason."""

    @pytest.mark.skip(reason="Tested in tests/arrays/test_boolean.py")
    def test_compare_array(self, data, all_compare_operators):
        """Skipped; see the skip reason."""


class TestReshaping(base.BaseReshapingTests):
    """Inherited reshaping tests, run without modification."""


class TestMethods(base.BaseMethodsTests):
    """Method tests, overridden where the base tests need 3 unique values."""

    @pytest.mark.parametrize("na_sentinel", [-1, -2])
    def test_factorize(self, data_for_grouping, na_sentinel):
        """Factorize yields two uniques and the sentinel for missing entries."""
        # override because we only have 2 unique values
        codes, uniques = pd.factorize(data_for_grouping, na_sentinel=na_sentinel)

        raw_codes = [0, 0, na_sentinel, na_sentinel, 1, 1, 0]
        expected_codes = np.array(raw_codes, dtype=np.intp)
        expected_uniques = data_for_grouping.take([0, 4])

        tm.assert_numpy_array_equal(codes, expected_codes)
        self.assert_extension_array_equal(uniques, expected_uniques)

    def test_combine_le(self, data_repeated):
        """``combine`` with ``<=`` produces a "boolean"-dtype Series."""
        # override because expected needs to be boolean instead of bool dtype

        def less_equal(x1, x2):
            return x1 <= x2

        orig_data1, orig_data2 = data_repeated(2)
        left = pd.Series(orig_data1)
        right = pd.Series(orig_data2)

        # Series vs Series
        result = left.combine(right, less_equal)
        elementwise = [a <= b for a, b in zip(list(orig_data1), list(orig_data2))]
        expected = pd.Series(elementwise, dtype="boolean")
        self.assert_series_equal(result, expected)

        # Series vs scalar
        scalar = left.iloc[0]
        result = left.combine(scalar, less_equal)
        against_scalar = [a <= scalar for a in list(orig_data1)]
        expected = pd.Series(against_scalar, dtype="boolean")
        self.assert_series_equal(result, expected)

    def test_searchsorted(self, data_for_sorting, as_series):
        """``searchsorted`` positions on a sorted two-element array."""
        # override because we only have 2 unique values; the fixture value is
        # not used, a local [True, False] array takes its place
        two_values = pd.array([True, False], dtype="boolean")
        high, low = two_values
        arr = type(two_values)._from_sequence([low, high])
        if as_series:
            arr = pd.Series(arr)

        assert arr.searchsorted(low) == 0
        assert arr.searchsorted(low, side="right") == 1

        assert arr.searchsorted(high) == 1
        assert arr.searchsorted(high, side="right") == 2

        result = arr.searchsorted(arr.take([0, 1]))
        expected = np.array([0, 1], dtype=np.intp)
        tm.assert_numpy_array_equal(result, expected)

        # sorter: applied to the unsorted [True, False] array
        sorter = np.array([1, 0])
        assert two_values.searchsorted(low, sorter=sorter) == 0

    @pytest.mark.skip(reason="uses nullable integer")
    def test_value_counts(self, all_data, dropna):
        """Skipped; the body would defer to the parent implementation."""
        return super().test_value_counts(all_data, dropna)

    @pytest.mark.skip(reason="uses nullable integer")
    def test_value_counts_with_normalize(self, data):
        """Skipped; see the skip reason."""

    def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting):
        """``argmin``/``argmax`` positions, including repeats and missing values."""
        # override because there are only 2 unique values

        # data_for_sorting -> [B, C, A] with A < B < C -> here True, True, False
        assert data_for_sorting.argmax() == 0
        assert data_for_sorting.argmin() == 2

        # with repeated values -> first occurrence
        repeated = data_for_sorting.take([2, 0, 0, 1, 1, 2])
        assert repeated.argmax() == 1
        assert repeated.argmin() == 0

        # with missing values
        # data_missing_for_sorting -> [B, NA, A] with A < B and NA missing.
        assert data_missing_for_sorting.argmax() == 0
        assert data_missing_for_sorting.argmin() == 2


class TestCasting(base.BaseCastingTests):
    """Inherited casting tests, run without modification."""


class TestGroupby(base.BaseGroupbyTests):
    """
    Groupby-specific tests are overridden because boolean only has 2
    unique values, base tests uses 3 groups.
    """

    def test_grouping_grouper(self, data_for_grouping):
        """The grouping objects expose the original column values."""
        keys = ["B", "B", None, None, "A", "A", "B"]
        frame = pd.DataFrame({"A": keys, "B": data_for_grouping})

        grouping_a = frame.groupby("A").grouper.groupings[0]
        grouping_b = frame.groupby("B").grouper.groupings[0]

        tm.assert_numpy_array_equal(grouping_a.grouper, frame.A.values)
        tm.assert_extension_array_equal(grouping_b.grouper, data_for_grouping)

    @pytest.mark.parametrize("as_index", [True, False])
    def test_groupby_extension_agg(self, as_index, data_for_grouping):
        """Mean of "A" per sorted "B" group, as a Series or a reset frame."""
        frame = _make_grouping_frame(data_for_grouping)
        result = frame.groupby("B", as_index=as_index).A.mean()

        _, uniques = pd.factorize(data_for_grouping, sort=True)
        expected = pd.Series([3, 1], index=pd.Index(uniques, name="B"), name="A")

        if not as_index:
            self.assert_frame_equal(result, expected.reset_index())
            return
        self.assert_series_equal(result, expected)

    def test_groupby_agg_extension(self, data_for_grouping):
        """Three spellings of "first" per "A" group give the same frame."""
        # GH#38980 groupby agg on extension type fails for non-numeric types
        frame = _make_grouping_frame(data_for_grouping)

        # rows 0, 2 and 4 are the first row of groups 1, 2 and 3
        expected = frame.iloc[[0, 2, 4]].set_index("A")

        via_dict = frame.groupby("A").agg({"B": "first"})
        self.assert_frame_equal(via_dict, expected)

        via_name = frame.groupby("A").agg("first")
        self.assert_frame_equal(via_name, expected)

        via_method = frame.groupby("A").first()
        self.assert_frame_equal(via_method, expected)

    def test_groupby_extension_no_sort(self, data_for_grouping):
        """With ``sort=False`` the groups follow order of first appearance."""
        frame = _make_grouping_frame(data_for_grouping)
        result = frame.groupby("B", sort=False).A.mean()

        _, uniques = pd.factorize(data_for_grouping, sort=False)
        expected = pd.Series([1, 3], index=pd.Index(uniques, name="B"), name="A")

        self.assert_series_equal(result, expected)

    def test_groupby_extension_transform(self, data_for_grouping):
        """``transform(len)`` broadcasts each group's size to its rows."""
        not_missing = ~data_for_grouping.isna()
        valid = data_for_grouping[not_missing]
        frame = pd.DataFrame({"A": [1, 1, 3, 3, 1], "B": valid})

        result = frame.groupby("B").A.transform(len)
        expected = pd.Series([3, 3, 2, 2, 3], name="A")

        self.assert_series_equal(result, expected)

    def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op):
        """Call ``apply`` on each frame/column grouping; no result is asserted."""
        frame = _make_grouping_frame(data_for_grouping)
        frame.groupby("B").apply(groupby_apply_op)
        frame.groupby("B").A.apply(groupby_apply_op)
        frame.groupby("A").apply(groupby_apply_op)
        frame.groupby("A").B.apply(groupby_apply_op)

    def test_groupby_apply_identity(self, data_for_grouping):
        """Applying ``x.array`` returns each group's slice of column "B"."""
        frame = _make_grouping_frame(data_for_grouping)
        result = frame.groupby("A").B.apply(lambda x: x.array)

        rows_per_group = ([0, 1, 6], [2, 3], [4, 5])
        group_arrays = [frame.B.iloc[rows].array for rows in rows_per_group]
        expected = pd.Series(
            group_arrays,
            index=pd.Index([1, 2, 3], name="A"),
            name="B",
        )
        self.assert_series_equal(result, expected)

    def test_in_numeric_groupby(self, data_for_grouping):
        """Column "B" survives ``sum`` only when the dtype reports numeric."""
        frame = pd.DataFrame(
            {
                "A": [1, 1, 2, 2, 3, 3, 1],
                "B": data_for_grouping,
                "C": [1, 1, 1, 1, 1, 1, 1],
            }
        )
        result = frame.groupby("A").sum().columns

        is_numeric = data_for_grouping.dtype._is_numeric
        expected = pd.Index(["B", "C"]) if is_numeric else pd.Index(["C"])

        tm.assert_index_equal(result, expected)

    @pytest.mark.parametrize("min_count", [0, 10])
    def test_groupby_sum_mincount(self, data_for_grouping, min_count):
        """Group sums are Int64; the ``min_count=10`` case expects all ``pd.NA``."""
        frame = _make_grouping_frame(data_for_grouping)
        result = frame.groupby("A").sum(min_count=min_count)

        if min_count == 0:
            sums = pd.array([3, 0, 0], dtype="Int64")
        else:
            sums = pd.array([pd.NA] * 3, dtype="Int64")

        expected = pd.DataFrame({"B": sums}, index=pd.Index([1, 2, 3], name="A"))
        tm.assert_frame_equal(result, expected)


class TestNumericReduce(base.BaseNumericReduceTests):
    """Numeric reductions checked against the float64 equivalent."""

    def check_reduce(self, s, op_name, skipna):
        """Compare the reduction of ``s`` with that of ``s`` cast to float64."""

        def reduce(series):
            return getattr(series, op_name)(skipna=skipna)

        result = reduce(s)
        expected = reduce(s.astype("float64"))

        # override parent function to cast to bool for min/max
        if np.isnan(expected):
            expected = pd.NA
        elif op_name in ("min", "max"):
            expected = bool(expected)

        tm.assert_almost_equal(result, expected)


class TestBooleanReduce(base.BaseBooleanReduceTests):
    """Inherited boolean-reduction tests, run without modification."""


class TestPrinting(base.BasePrintingTests):
    """Inherited printing tests, run without modification."""


class TestUnaryOps(base.BaseUnaryOpsTests):
    """Inherited unary-op tests, run without modification."""


class TestParsing(base.BaseParsingTests):
    """Inherited parsing tests, run without modification."""