1"""
2Rudimentary Apache Arrow-backed ExtensionArray.
3
At the moment, only boolean and string arrays / types are implemented.
5Eventually, we'll want to parametrize the type and support
6multiple dtypes. Not all methods are implemented yet, and the
7current implementation is not efficient.
8"""
9import copy
10import itertools
11import operator
12from typing import Type
13
14import numpy as np
15import pyarrow as pa
16
17import pandas as pd
18from pandas.api.extensions import (
19    ExtensionArray,
20    ExtensionDtype,
21    register_extension_dtype,
22    take,
23)
24from pandas.core.arraylike import OpsMixin
25
26
@register_extension_dtype
class ArrowBoolDtype(ExtensionDtype):
    """Extension dtype for boolean values held in a pyarrow ChunkedArray."""

    # String alias of this dtype.
    name = "arrow_bool"
    # Scalar type and NumPy kind code reported for elements of this dtype.
    type = np.bool_
    kind = "b"
    # Sentinel used for missing values.
    na_value = pa.NULL

    @classmethod
    def construct_array_type(cls) -> Type["ArrowBoolArray"]:
        """
        Return the ExtensionArray subclass paired with this dtype.

        Returns
        -------
        type
            The ``ArrowBoolArray`` class.
        """
        array_cls = ArrowBoolArray
        return array_cls

    @property
    def _is_boolean(self) -> bool:
        """Report this dtype as boolean."""
        return True
49
50
@register_extension_dtype
class ArrowStringDtype(ExtensionDtype):
    """Extension dtype for string values held in a pyarrow ChunkedArray."""

    # String alias of this dtype.
    name = "arrow_string"
    # Scalar type and NumPy kind code reported for elements of this dtype.
    type = str
    kind = "U"
    # Sentinel used for missing values.
    na_value = pa.NULL

    @classmethod
    def construct_array_type(cls) -> Type["ArrowStringArray"]:
        """
        Return the ExtensionArray subclass paired with this dtype.

        Returns
        -------
        type
            The ``ArrowStringArray`` class.
        """
        array_cls = ArrowStringArray
        return array_cls
69
70
class ArrowExtensionArray(OpsMixin, ExtensionArray):
    """
    Base ExtensionArray backed by a ``pyarrow.ChunkedArray``.

    Subclasses set ``_data`` (the chunked array) and ``_dtype`` (the matching
    dtype instance) in ``__init__``. Most operations round-trip through
    ``ChunkedArray.to_pandas()``, favouring simplicity over speed.
    """

    _data: pa.ChunkedArray

    @classmethod
    def from_scalars(cls, values):
        """Build a single-chunk array from an iterable of scalars."""
        arr = pa.chunked_array([pa.array(np.asarray(values))])
        return cls(arr)

    @classmethod
    def from_array(cls, arr):
        """Wrap one ``pyarrow.Array`` in a single-chunk ChunkedArray."""
        assert isinstance(arr, pa.Array)
        return cls(pa.chunked_array([arr]))

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """ExtensionArray constructor hook; ``dtype`` and ``copy`` are ignored."""
        return cls.from_scalars(scalars)

    def __repr__(self):
        return f"{type(self).__name__}({repr(self._data)})"

    def __getitem__(self, item):
        """
        Scalar ``item`` returns one element; any other indexer returns a new
        array of the same type built from the selected values.
        """
        # NOTE(review): if ``to_pandas()`` returns a Series, ``[]`` is
        # label-based, so negative integers may not index positionally — confirm.
        if pd.api.types.is_scalar(item):
            return self._data.to_pandas()[item]
        else:
            vals = self._data.to_pandas()[item]
            return type(self).from_scalars(vals)

    def __len__(self):
        return len(self._data)

    def astype(self, dtype, copy=True):
        """Cast to ``dtype``; same-dtype casts return self or a copy."""
        # needed to fix this astype for the Series constructor.
        if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
            if copy:
                return self.copy()
            return self
        return super().astype(dtype, copy)

    @property
    def dtype(self):
        return self._dtype

    def _logical_method(self, other, op):
        """
        Apply the binary ``op`` elementwise to two arrays of the same type.

        Returns an ``ArrowBoolArray`` that is null wherever *either* operand
        is null.

        Raises
        ------
        NotImplementedError
            If ``other`` is not an instance of ``type(self)``.
        """
        if not isinstance(other, type(self)):
            raise NotImplementedError()

        result = op(np.array(self._data), np.array(other._data))
        # Propagate nulls from both sides; masking on ``self`` alone would
        # report a concrete value at positions where ``other`` is missing.
        mask = np.asarray(pd.isna(self._data.to_pandas()), dtype=bool) | np.asarray(
            pd.isna(other._data.to_pandas()), dtype=bool
        )
        return ArrowBoolArray(pa.chunked_array([pa.array(result, mask=mask)]))

    def __eq__(self, other):
        """
        Elementwise equality against an array of the same type.

        Any other operand yields the plain scalar ``False``, not an
        elementwise result.
        """
        if not isinstance(other, type(self)):
            return False

        return self._logical_method(other, operator.eq)

    @property
    def nbytes(self) -> int:
        """Total size of all non-None buffers across every chunk."""
        return sum(
            x.size
            for chunk in self._data.chunks
            for x in chunk.buffers()
            if x is not None
        )

    def isna(self):
        """Return an ``ArrowBoolArray`` that is True where values are missing."""
        nas = pd.isna(self._data.to_pandas())
        # Always build a boolean array: ``type(self).from_scalars`` would pass
        # boolean data to non-boolean subclasses such as ArrowStringArray,
        # whose constructor rejects it.
        return ArrowBoolArray.from_scalars(nas)

    def take(self, indices, allow_fill=False, fill_value=None):
        """
        Take elements at ``indices``.

        With ``allow_fill`` and no ``fill_value``, missing positions are
        filled with ``self.dtype.na_value``.
        """
        data = self._data.to_pandas()

        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value

        result = take(data, indices, fill_value=fill_value, allow_fill=allow_fill)
        return self._from_sequence(result, dtype=self.dtype)

    def copy(self):
        """Return a new array wrapping a shallow copy of the ChunkedArray."""
        return type(self)(copy.copy(self._data))

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate by gathering the chunks of every input into one ChunkedArray."""
        chunks = list(itertools.chain.from_iterable(x._data.chunks for x in to_concat))
        arr = pa.chunked_array(chunks)
        return cls(arr)

    def __invert__(self):
        """Elementwise ``~`` computed on the pandas conversion of the data."""
        return type(self).from_scalars(~self._data.to_pandas())

    def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
        """
        Dispatch reduction ``name`` to a same-named method on this array.

        Raises
        ------
        TypeError
            If this array has no method called ``name``.
        """
        if skipna:
            arr = self[~self.isna()]
        else:
            arr = self

        try:
            op = getattr(arr, name)
        except AttributeError as err:
            raise TypeError(
                f"'{type(self).__name__}' does not support reduction '{name}'"
            ) from err
        return op(**kwargs)

    def any(self, axis=0, out=None):
        # Explicitly return a plain bool to reproduce GH-34660
        return bool(self._data.to_pandas().any())

    def all(self, axis=0, out=None):
        # Explicitly return a plain bool to reproduce GH-34660
        return bool(self._data.to_pandas().all())
181
182
class ArrowBoolArray(ArrowExtensionArray):
    """Boolean ExtensionArray backed by a ``pyarrow.ChunkedArray``."""

    def __init__(self, values):
        """
        Parameters
        ----------
        values : pyarrow.ChunkedArray
            Chunked array whose type must be ``pyarrow.bool_()``.

        Raises
        ------
        ValueError
            If ``values`` is not a ChunkedArray or is not boolean-typed.
        """
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError(
                f"Expected a pyarrow.ChunkedArray, got {type(values).__name__}"
            )
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if values.type != pa.bool_():
            raise ValueError(
                f"Expected a ChunkedArray of type bool, got {values.type}"
            )

        self._data = values
        self._dtype = ArrowBoolDtype()
191
192
class ArrowStringArray(ArrowExtensionArray):
    """String ExtensionArray backed by a ``pyarrow.ChunkedArray``."""

    def __init__(self, values):
        """
        Parameters
        ----------
        values : pyarrow.ChunkedArray
            Chunked array whose type must be ``pyarrow.string()``.

        Raises
        ------
        ValueError
            If ``values`` is not a ChunkedArray or is not string-typed.
        """
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError(
                f"Expected a pyarrow.ChunkedArray, got {type(values).__name__}"
            )
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if values.type != pa.string():
            raise ValueError(
                f"Expected a ChunkedArray of type string, got {values.type}"
            )

        self._data = values
        self._dtype = ArrowStringDtype()
201