1"""
2This file contains a minimal set of tests for compliance with the extension
3array interface test suite (by inheriting the pandas test suite), and should
4contain no other tests.
5Other tests (eg related to the spatial functionality or integration
6with GeoSeries/GeoDataFrame) should be added to test_array.py and others.
7
8The tests in this file are inherited from the BaseExtensionTests, and only
9minimal tweaks should be applied to get the tests passing (by overwriting a
10parent method).
11
A set of fixtures is defined to provide data for the tests (these are the
fixtures that the inherited pandas tests expect pytest to provide).
14
15"""
16import operator
17
18import numpy as np
19from numpy.testing import assert_array_equal
20import pandas as pd
21from pandas.tests.extension import base as extension_tests
22
23import shapely.geometry
24
25from geopandas.array import GeometryArray, GeometryDtype, from_shapely
26from geopandas._compat import ignore_shapely2_warnings
27
28import pytest
29
30# -----------------------------------------------------------------------------
31# Compat with extension tests in older pandas versions
32# -----------------------------------------------------------------------------
33
34
# Skip marker for inherited tests whose functionality is not yet implemented.
not_yet_implemented = pytest.mark.skip(reason="Not yet implemented")
# Skip marker for inherited tests that rely on sorting, which is not supported.
no_sorting = pytest.mark.skip(reason="Sorting not supported")
37
38
39# -----------------------------------------------------------------------------
40# Required fixtures
41# -----------------------------------------------------------------------------
42
43
@pytest.fixture
def dtype():
    """Return a new ``GeometryDtype`` instance, the ExtensionDtype under test."""
    geometry_dtype = GeometryDtype()
    return geometry_dtype
48
49
def make_data():
    """Build the array of the 100 points (0, 0), (1, 1), ..., (99, 99).

    The points are written into a pre-allocated object array by slice
    assignment, inside ``ignore_shapely2_warnings``, and the result of
    ``from_shapely`` on that array is returned.
    """
    n_points = 100
    holder = np.empty(n_points, dtype=object)
    with ignore_shapely2_warnings():
        holder[:] = [shapely.geometry.Point(xy, xy) for xy in range(n_points)]
    return from_shapely(holder)
56
57
@pytest.fixture
def data():
    """Provide the length-100 array produced by ``make_data``.

    Contract expected by the inherited tests:

    * data[0] and data[1] are both non-missing
    * data[0] and data[1] are not equal
    """
    geometries = make_data()
    return geometries
66
67
@pytest.fixture
def data_for_twos():
    """Length-100 array in which all the elements are two.

    Not provided here: requesting this fixture raises ``NotImplementedError``.
    """
    raise NotImplementedError()
72
73
@pytest.fixture
def data_missing():
    """Provide a length-2 array laid out as [NA, Valid]."""
    valid_point = shapely.geometry.Point(1, 1)
    return from_shapely([None, valid_point])
78
79
@pytest.fixture(params=["data", "data_missing"])
def all_data(request, data, data_missing):
    """Parametrized fixture yielding ``data`` and then ``data_missing``."""
    # Look the parameter up by name; an unknown name yields None, as before.
    by_name = {"data": data, "data_missing": data_missing}
    return by_name.get(request.param)
87
88
@pytest.fixture
def data_repeated(data):
    """
    Provide a factory that repeats the ``data`` fixture.

    Parameters
    ----------
    data : fixture implementing `data`

    Returns
    -------
    Callable[[int], Generator]:
        A callable taking a `count` argument and returning a generator
        that yields the same `data` object `count` times.
    """

    def gen(count):
        # Still a generator function, so nothing runs until iteration starts.
        yield from (data for _ in range(count))

    return gen
110
111
@pytest.fixture
def data_for_sorting():
    """Length-3 array with a known sort order.

    The expected layout is three items [B, C, A] with A < B < C.
    Not provided here: requesting this fixture raises ``NotImplementedError``.
    """
    raise NotImplementedError()
120
121
@pytest.fixture
def data_missing_for_sorting():
    """Length-3 array with a known sort order and one missing value.

    The expected layout is three items [B, NA, A] with A < B and NA missing.
    Not provided here: requesting this fixture raises ``NotImplementedError``.
    """
    raise NotImplementedError()
130
131
@pytest.fixture
def na_cmp():
    """Binary predicate for comparing NA values.

    Returns a function of two arguments that is True only when both
    arguments are ``None``, the scalar used for NA in this file's fixtures.
    """

    def both_missing(x, y):
        return x is None and y is None

    return both_missing
140
141
@pytest.fixture
def na_value():
    """Return the scalar missing value used by these tests: ``None``."""
    missing = None
    return missing
146
147
@pytest.fixture
def data_for_grouping():
    """Data for factorization, grouping, and unique tests.

    Laid out as [B, B, NA, NA, A, A, B, C], where A, B and C are the points
    (0, 0), (1, 1) and (2, 2), and NA is ``None``.
    """
    # One coordinate per element; None marks a missing entry.
    layout = (1, 1, None, None, 0, 0, 1, 2)
    geometries = [
        None if coord is None else shapely.geometry.Point(coord, coord)
        for coord in layout
    ]
    return from_shapely(geometries)
168
169
@pytest.fixture(params=[True, False])
def box_in_series(request):
    """Boolean fixture: whether the data should be boxed in a Series."""
    should_box = request.param
    return should_box
174
175
@pytest.fixture(
    params=[
        lambda x: 1,
        lambda x: [1] * len(x),
        lambda x: pd.Series([1] * len(x)),
        lambda x: x,
    ],
    ids=["scalar", "list", "series", "object"],
)
def groupby_apply_op(request):
    """
    Callables for testing groupby.apply(), one per parametrization.

    The ids name what each callable returns: a scalar, a list, a Series,
    or its input unchanged.
    """
    apply_func = request.param
    return apply_func
190
191
@pytest.fixture(params=[True, False])
def as_frame(request):
    """
    Boolean fixture for comparing a Series against Series.to_frame().
    """
    use_frame = request.param
    return use_frame
198
199
@pytest.fixture(params=[True, False])
def as_series(request):
    """
    Boolean fixture for comparing an array against Series(arr).
    """
    wrap_in_series = request.param
    return wrap_in_series
206
207
@pytest.fixture(params=[True, False])
def use_numpy(request):
    """
    Boolean fixture for comparing an ExtensionDtype array against a
    numpy array.
    """
    with_numpy = request.param
    return with_numpy
215
216
@pytest.fixture(params=["ffill", "bfill"])
def fillna_method(request):
    """
    Parametrized fixture yielding the method names 'ffill' and 'bfill',
    for Series.fillna(method=<method>) testing.
    """
    method_name = request.param
    return method_name
224
225
@pytest.fixture(params=[True, False])
def as_array(request):
    """
    Boolean fixture for testing the ExtensionDtype _from_sequence method.
    """
    pass_array = request.param
    return pass_array
232
233
234# Fixtures defined in pandas/conftest.py that are also needed: defining them
235# here instead of importing for compatibility
236
237
@pytest.fixture(
    params=[
        "sum",
        "max",
        "min",
        "mean",
        "prod",
        "std",
        "var",
        "median",
        "kurt",
        "skew",
    ]
)
def all_numeric_reductions(request):
    """
    Parametrized fixture yielding the name of each numeric reduction.
    """
    reduction_name = request.param
    return reduction_name
246
247
@pytest.fixture(params=["all", "any"])
def all_boolean_reductions(request):
    """
    Parametrized fixture yielding the boolean reduction names "all" and "any".
    """
    reduction_name = request.param
    return reduction_name
254
255
# only == and != are supported for GeometryArray
257# @pytest.fixture(params=["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"])
@pytest.fixture(params=["__eq__", "__ne__"])
def all_compare_operators(request):
    """
    Parametrized fixture yielding dunder names of comparison operators.

    Only the two names listed in ``params`` are produced:

    * ==
    * !=
    """
    dunder_name = request.param
    return dunder_name
271
272
273# -----------------------------------------------------------------------------
274# Inherited tests
275# -----------------------------------------------------------------------------
276
277
class TestDtype(extension_tests.BaseDtypeTests):
    """Inherited dtype tests plus two GeometryDtype-specific checks."""

    # additional tests

    def test_array_type_with_arg(self, data, dtype):
        # The dtype must report GeometryArray as its array type.
        array_type = dtype.construct_array_type()
        assert array_type is GeometryArray

    def test_registry(self, data, dtype):
        # Casting an object Series with the string alias "geometry" must
        # produce a GeometryArray-backed Series equal to pd.Series(data).
        as_objects = pd.Series(np.asarray(data), dtype=object)
        converted = as_objects.astype("geometry")
        assert isinstance(converted.array, GeometryArray)
        self.assert_series_equal(converted, pd.Series(data))
291
292
class TestInterface(extension_tests.BaseInterfaceTests):
    """Inherited interface tests with two overrides for geometries."""

    def test_array_interface(self, data):
        # we are overriding this base test because the creation of `expected`
        # potentially doesn't work for shapely geometries
        # TODO can be removed with Shapely 2.0
        first_converted = np.array(data)[0]
        assert first_converted == data[0]

        as_objects = np.array(data, dtype=object)
        # Instead of np.array(list(data), dtype=object), fill a pre-allocated
        # object array by slice assignment.
        expected = np.empty(len(data), dtype=object)
        with ignore_shapely2_warnings():
            expected[:] = list(data)
        assert_array_equal(as_objects, expected)

    def test_contains(self, data, data_missing):
        # overridden due to the inconsistency between
        # GeometryDtype.na_value = np.nan
        # and None being used as NA in array

        # keep only the non-missing entries of `data`
        non_missing = data[~data.isna()]

        # the first element of each array is a member of that array
        assert non_missing[0] in non_missing
        assert data_missing[0] in data_missing

        assert None in data_missing
        assert None not in non_missing
        assert pd.NaT not in data_missing
323
324
class TestConstructors(extension_tests.BaseConstructorsTests):
    """Run the inherited constructor tests without any overrides."""
327
328
class TestReshaping(extension_tests.BaseReshapingTests):
    """Run the inherited reshaping tests without any overrides."""
331
332
class TestGetitem(extension_tests.BaseGetitemTests):
    """Run the inherited getitem tests without any overrides."""
335
336
class TestSetitem(extension_tests.BaseSetitemTests):
    """Run the inherited setitem tests without any overrides."""
339
340
class TestMissing(extension_tests.BaseMissingTests):
    """Inherited missing-value tests, limited to filling with a scalar.

    The tests that fill via a method (pad / backfill) are skipped with the
    reason "fillna method not supported".
    """

    def test_fillna_series(self, data_missing):
        # Overridden so that only the scalar-fill part of the check runs.
        fill_value = data_missing[1]
        ser = pd.Series(data_missing)

        result = ser.fillna(fill_value)
        expected = pd.Series(data_missing._from_sequence([fill_value, fill_value]))
        self.assert_series_equal(result, expected)

        # filling with array-like not yet supported

        # # Fill with a series
        # result = ser.fillna(expected)
        # self.assert_series_equal(result, expected)

        # # Fill with a series not affecting the missing values
        # result = ser.fillna(ser)
        # self.assert_series_equal(result, ser)

    @pytest.mark.skip(reason="fillna method not supported")
    def test_fillna_limit_pad(self, data_missing):
        pass

    @pytest.mark.skip(reason="fillna method not supported")
    def test_fillna_limit_backfill(self, data_missing):
        pass

    # Request the `fillna_method` fixture defined in this module; the former
    # argument name `method` did not correspond to any fixture.
    @pytest.mark.skip(reason="fillna method not supported")
    def test_fillna_series_method(self, data_missing, fillna_method):
        pass

    @pytest.mark.skip(reason="fillna method not supported")
    def test_fillna_no_op_returns_copy(self, data):
        pass
375
376
class TestReduce(extension_tests.BaseNoReduceTests):
    """Inherited reduction tests, with the boolean-reduction test skipped."""

    # The override is a method, so it must accept `self`; the original
    # definition took no parameters at all.
    @pytest.mark.skip(reason="boolean reduce (any/all) tested in test_pandas_methods")
    def test_reduce_series_boolean(self):
        pass
381
382
# Dunder names of the arithmetic operators to test: for each base name, the
# plain form followed by its reflected ("r"-prefixed) form.
# "sub" is deliberately absent, so neither '__sub__' nor '__rsub__' is listed.
_all_arithmetic_operators = [
    "__" + reflected + base_name + "__"
    for base_name in ("add", "mul", "floordiv", "truediv", "pow", "mod")
    for reflected in ("", "r")
]
398
399
@pytest.fixture(params=_all_arithmetic_operators)
def all_arithmetic_operators(request):
    """
    Parametrized fixture yielding dunder names of arithmetic operations.

    The names come from ``_all_arithmetic_operators``, which excludes
    __sub__ because that is implemented as "difference".
    """
    dunder_name = request.param
    return dunder_name
408
409
410# an inherited test from pandas creates a Series from a list of geometries, which
411# triggers the warning from Shapely, out of control of GeoPandas, so ignoring here
@pytest.mark.filterwarnings(
    "ignore:The array interface is deprecated and will no longer work in Shapely 2.0"
)
class TestArithmeticOps(extension_tests.BaseArithmeticOpsTests):
    """Inherited arithmetic tests; two of them are skipped as not applicable."""

    @pytest.mark.skip(reason="not applicable")
    def test_divmod_series_array(self, data, data_for_twos):
        """Skipped: not applicable."""

    @pytest.mark.skip(reason="not applicable")
    def test_add_series_with_extension_array(self, data):
        """Skipped: not applicable."""
423
424
425# an inherited test from pandas creates a Series from a list of geometries, which
426# triggers the warning from Shapely, out of control of GeoPandas, so ignoring here
@pytest.mark.filterwarnings(
    "ignore:The array interface is deprecated and will no longer work in Shapely 2.0"
)
class TestComparisonOps(extension_tests.BaseComparisonOpsTests):
    """Comparison tests against a scalar and against a Series."""

    def _compare_other(self, s, data, op_name, other):
        # Map the dunder name (e.g. "__eq__") onto the function of the same
        # name in the `operator` module (e.g. operator.eq).
        compare = getattr(operator, op_name.strip("_"))
        result = compare(s, other)
        # Reference result: apply the same operator through Series.combine.
        expected = s.combine(other, compare)
        self.assert_series_equal(result, expected)

    def test_compare_scalar(self, data, all_compare_operators):  # noqa
        series = pd.Series(data)
        scalar = data[0]
        self._compare_other(series, data, all_compare_operators, scalar)

    def test_compare_array(self, data, all_compare_operators):  # noqa
        series = pd.Series(data)
        # A Series of the same length in which every entry is data[0].
        first = data[0]
        repeated = pd.Series([first] * len(data))
        self._compare_other(series, data, all_compare_operators, repeated)
447
448
class TestMethods(extension_tests.BaseMethodsTests):
    """Inherited method tests with the unsupported ones skipped.

    Tests decorated with ``no_sorting`` are skipped with the reason
    "Sorting not supported"; those decorated with ``not_yet_implemented``
    are skipped with the reason "Not yet implemented".
    """

    @no_sorting
    @pytest.mark.parametrize("dropna", [True, False])
    def test_value_counts(self, all_data, dropna):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_value_counts_with_normalize(self, data):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_argsort(self, data_for_sorting):
        positions = pd.Series(data_for_sorting).argsort()
        wanted = pd.Series(np.array([2, 0, 1], dtype=np.int64))
        self.assert_series_equal(positions, wanted)

    @no_sorting
    def test_argsort_missing(self, data_missing_for_sorting):
        positions = pd.Series(data_missing_for_sorting).argsort()
        wanted = pd.Series(np.array([1, -1, 0], dtype=np.int64))
        self.assert_series_equal(positions, wanted)

    @no_sorting
    @pytest.mark.parametrize("ascending", [True, False])
    def test_sort_values(self, data_for_sorting, ascending):
        series = pd.Series(data_for_sorting)
        result = series.sort_values(ascending=ascending)
        in_order = series.iloc[[2, 0, 1]]
        expected = in_order if ascending else in_order[::-1]

        self.assert_series_equal(result, expected)

    @no_sorting
    @pytest.mark.parametrize("ascending", [True, False])
    def test_sort_values_missing(self, data_missing_for_sorting, ascending):
        series = pd.Series(data_missing_for_sorting)
        result = series.sort_values(ascending=ascending)
        # The missing value (position 1) is expected last in both directions.
        positions = [2, 0, 1] if ascending else [0, 2, 1]
        expected = series.iloc[positions]
        self.assert_series_equal(result, expected)

    @no_sorting
    @pytest.mark.parametrize("ascending", [True, False])
    def test_sort_values_frame(self, data_for_sorting, ascending):
        frame = pd.DataFrame({"A": [1, 2, 1], "B": data_for_sorting})
        result = frame.sort_values(["A", "B"])
        sorted_b = data_for_sorting.take([2, 0, 1])
        expected = pd.DataFrame({"A": [1, 1, 2], "B": sorted_b}, index=[2, 0, 1])
        self.assert_frame_equal(result, expected)

    @no_sorting
    def test_searchsorted(self, data_for_sorting, as_series):
        """Skipped: sorting not supported."""

    @not_yet_implemented
    def test_combine_le(self):
        """Skipped: not yet implemented."""

    @pytest.mark.skip(reason="addition not supported")
    def test_combine_add(self):
        """Skipped: addition not supported."""

    @not_yet_implemented
    def test_fillna_length_mismatch(self, data_missing):
        expected_message = "Length of 'value' does not match."
        with pytest.raises(ValueError, match=expected_message):
            data_missing.fillna(data_missing.take([1]))

    @no_sorting
    def test_nargsort(self):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_argsort_missing_array(self):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_argmin_argmax(self):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_argmin_argmax_empty_array(self):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_argmin_argmax_all_na(self):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_argreduce_series(self):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_argmax_argmin_no_skipna_notimplemented(self):
        """Skipped: sorting not supported."""
548
549
class TestCasting(extension_tests.BaseCastingTests):
    """Run the inherited casting tests without any overrides."""
552
553
class TestGroupby(extension_tests.BaseGroupbyTests):
    """Inherited groupby tests; three are skipped via ``no_sorting``."""

    @no_sorting
    @pytest.mark.parametrize("as_index", [True, False])
    def test_groupby_extension_agg(self, as_index, data_for_grouping):
        """Skipped: sorting not supported."""

    @no_sorting
    def test_groupby_extension_transform(self, data_for_grouping):
        """Skipped: sorting not supported."""

    @no_sorting
    @pytest.mark.parametrize(
        "op",
        [
            lambda x: 1,
            lambda x: [1] * len(x),
            lambda x: pd.Series([1] * len(x)),
            lambda x: x,
        ],
        ids=["scalar", "list", "series", "object"],
    )
    def test_groupby_extension_apply(self, data_for_grouping, op):
        """Skipped: sorting not supported."""
577
578
class TestPrinting(extension_tests.BasePrintingTests):
    """Run the inherited printing tests without any overrides."""
581
582
@not_yet_implemented
class TestParsing(extension_tests.BaseParsingTests):
    """Inherited parsing tests, skipped as a whole: not yet implemented."""
586