1# Copyright (C) 2016-present the asyncpg authors and contributors
2# <see AUTHORS file>
3#
4# This module is part of asyncpg and is released under
5# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
6
7
8import builtins
9import sys
10import typing
11
12if sys.version_info >= (3, 8):
13    from typing import Literal, SupportsIndex
14else:
15    from typing_extensions import Literal, SupportsIndex
16
17
# Public names exported by ``from ... import *``.
__all__ = (
    'BitString',
    'Point',
    'Path',
    'Polygon',
    'Box',
    'Line',
    'LineSegment',
    'Circle',
)

# Bound type variable so BitString's alternate constructors are typed as
# returning the (sub)class they are invoked on.
_BitString = typing.TypeVar('_BitString', bound='BitString')
# Accepted values for the ``bitorder`` argument of BitString.to_int/from_int.
_BitOrderType = Literal['big', 'little']
25
26
class BitString:
    """Immutable representation of PostgreSQL `bit` and `varbit` types.

    Bits are packed most-significant-bit first into ``_bytes``.  When the
    length is not a multiple of 8, the trailing low-order bits of the last
    byte are padding and are not part of the value.
    """

    __slots__ = '_bytes', '_bitlength'

    def __init__(self,
                 bitstring: typing.Optional[
                     typing.Union[str, builtins.bytes]] = None) -> None:
        """Build a BitString from a textual bit string such as ``'1010 1'``.

        :param bitstring:
            An iterable of bit values, typically a ``str``.  Space
            characters are skipped; every other item must convert with
            ``int()`` to 0 or 1.  ``None`` or an empty value produces an
            empty BitString.

        :raises ValueError: if an item is not a valid bit value.
        """
        if not bitstring:
            self._bytes = bytes()
            self._bitlength = 0
            return

        bits = []
        for i, item in enumerate(bitstring):
            if item == ' ':
                continue
            try:
                bit = int(item)
            except ValueError:
                # Report non-numeric characters with the same positional
                # message as out-of-range digits.
                bit = -1
            if bit != 0 and bit != 1:
                raise ValueError(
                    'invalid bit value at position {}'.format(i))
            bits.append(bit)

        packed = bytearray((len(bits) + 7) // 8)
        for pos, bit in enumerate(bits):
            if bit:
                packed[pos // 8] |= 0x80 >> (pos % 8)

        self._bytes = bytes(packed)
        self._bitlength = len(bits)

    @classmethod
    def frombytes(cls: typing.Type[_BitString],
                  bytes_: typing.Optional[builtins.bytes] = None,
                  bitlength: typing.Optional[int] = None) -> _BitString:
        """Build a BitString directly from its packed byte representation.

        :param bytes_:
            Packed bits, most significant bit of the first byte first.
            When omitted, all-zero bytes just large enough for *bitlength*
            are used.
        :param bitlength:
            Number of significant bits.  Defaults to ``len(bytes_) * 8``.

        :raises ValueError:
            if *bitlength* is negative, or if *bytes_* does not contain
            exactly ``ceil(bitlength / 8)`` bytes.
        """
        if bitlength is None:
            if bytes_ is None:
                bytes_ = bytes()
                bitlength = 0
            else:
                bitlength = len(bytes_) * 8
        else:
            if bitlength < 0:
                raise ValueError('invalid bit length specified')
            # Exactly ceil(bitlength / 8) bytes: an extra trailing byte
            # would make otherwise-equal values compare unequal.
            nbytes = (bitlength + 7) // 8
            if bytes_ is None:
                bytes_ = bytes(nbytes)
            elif len(bytes_) != nbytes:
                raise ValueError('invalid bit length specified')

        result = cls()
        # Copy into immutable ``bytes`` so bytearray input cannot be mutated
        # afterwards and the instance stays hashable.
        result._bytes = bytes(bytes_)
        result._bitlength = bitlength

        return result

    @property
    def bytes(self) -> builtins.bytes:
        """The packed bits (MSB first, zero or more padding bits at the end)."""
        return self._bytes

    def as_string(self) -> str:
        """Return the bits as '0'/'1' characters grouped in fours by spaces."""
        bits = ''.join(str(self._getitem(i)) for i in range(self._bitlength))
        return ' '.join(bits[i:i + 4] for i in range(0, len(bits), 4))

    def to_int(self, bitorder: _BitOrderType = 'big',
               *, signed: bool = False) -> int:
        """Interpret the BitString as a Python int.
        Acts similarly to int.from_bytes.

        :param bitorder:
            Determines the bit order used to interpret the BitString. By
            default, this function uses Postgres conventions for casting bits
            to ints. If bitorder is 'big', the most significant bit is at the
            start of the string (this is the same as the default). If bitorder
            is 'little', the most significant bit is at the end of the string.

        :param bool signed:
            Determines whether two's complement is used to interpret the
            BitString. If signed is False, the returned value is always
            non-negative.

        :return int: An integer representing the BitString. Information about
                     the BitString's exact length is lost.

        .. versionadded:: 0.18.0
        """
        x = int.from_bytes(self._bytes, byteorder='big')
        # Drop the padding bits at the end of the last byte.
        x >>= -self._bitlength % 8
        if bitorder == 'big':
            pass
        elif bitorder == 'little':
            # Reverse the bit string; ljust restores the low-order zeros
            # that bin() drops as leading zeros of the reversed value.
            x = int(bin(x)[:1:-1].ljust(self._bitlength, '0'), 2)
        else:
            raise ValueError("bitorder must be either 'big' or 'little'")

        if signed and self._bitlength > 0 and x & (1 << (self._bitlength - 1)):
            x -= 1 << self._bitlength
        return x

    @classmethod
    def from_int(cls: typing.Type[_BitString], x: int, length: int,
                 bitorder: _BitOrderType = 'big', *, signed: bool = False) \
            -> _BitString:
        """Represent the Python int x as a BitString.
        Acts similarly to int.to_bytes.

        :param int x:
            An integer to represent. Negative integers are represented in two's
            complement form, unless the argument signed is False, in which case
            negative integers raise an OverflowError.

        :param int length:
            The length of the resulting BitString. An OverflowError is raised
            if the integer is not representable in this many bits (for
            signed=True this means the two's complement range
            ``-2**(length-1) <= x < 2**(length-1)``).

        :param bitorder:
            Determines the bit order used in the BitString representation. By
            default, this function uses Postgres conventions for casting ints
            to bits. If bitorder is 'big', the most significant bit is at the
            start of the string (this is the same as the default). If bitorder
            is 'little', the most significant bit is at the end of the string.

        :param bool signed:
            Determines whether two's complement is used in the BitString
            representation. If signed is False and a negative integer is given,
            an OverflowError is raised.

        :return BitString: A BitString representing the input integer, in the
                           form specified by the other input args.

        .. versionadded:: 0.18.0
        """
        # Exception types are by analogy to int.to_bytes
        if length < 0:
            raise ValueError("length argument must be non-negative")

        if signed:
            # Two's complement range for `length` bits; only 0 fits in 0 bits.
            if length:
                bound = 1 << (length - 1)
                fits = -bound <= x < bound
            else:
                fits = x == 0
            if not fits:
                raise OverflowError("int too big to convert")
            # Map negative values onto their unsigned bit pattern.
            x &= (1 << length) - 1
        else:
            if x < 0:
                raise OverflowError("can't convert negative int to unsigned")
            if x.bit_length() > length:
                raise OverflowError("int too big to convert")

        if bitorder == 'big':
            pass
        elif bitorder == 'little':
            x = int(bin(x)[:1:-1].ljust(length, '0'), 2)
        else:
            raise ValueError("bitorder must be either 'big' or 'little'")

        # Left-align the bits so the padding ends up at the end of the last
        # byte, matching the layout frombytes() expects.
        x <<= (-length % 8)
        bytes_ = x.to_bytes((length + 7) // 8, byteorder='big')
        return cls.frombytes(bytes_, length)

    def __repr__(self) -> str:
        return '<BitString {}>'.format(self.as_string())

    __str__ = __repr__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, BitString):
            return NotImplemented

        return (self._bytes == other._bytes and
                self._bitlength == other._bitlength)

    def __hash__(self) -> int:
        return hash((self._bytes, self._bitlength))

    def _getitem(self, i: int) -> int:
        """Return bit *i* (0-based, non-negative) without bounds checking."""
        byte = self._bytes[i // 8]
        shift = 8 - i % 8 - 1
        return (byte >> shift) & 0x1

    def __getitem__(self, i: int) -> int:
        """Return bit *i*; negative indices count from the end.

        :raises NotImplementedError: for slices.
        :raises IndexError: if *i* is out of range.
        """
        if isinstance(i, slice):
            raise NotImplementedError('BitString does not support slices')

        # Normalize negative indices against the bit length, not the byte
        # length, so padding bits are never returned.
        if i < 0:
            i += self._bitlength
        if not 0 <= i < self._bitlength:
            raise IndexError('index out of range')

        return self._getitem(i)

    def __len__(self) -> int:
        return self._bitlength
239
240
class Point(typing.Tuple[float, float]):
    """Immutable representation of PostgreSQL `point` type.

    A 2-tuple ``(x, y)``; both coordinates are converted with ``float()``
    when the point is created.
    """

    __slots__ = ()

    def __new__(cls,
                x: typing.Union[typing.SupportsFloat,
                                SupportsIndex,
                                typing.Text,
                                builtins.bytes,
                                builtins.bytearray],
                y: typing.Union[typing.SupportsFloat,
                                SupportsIndex,
                                typing.Text,
                                builtins.bytes,
                                builtins.bytearray]) -> 'Point':
        coords = float(x), float(y)
        return super().__new__(cls, typing.cast(typing.Any, coords))

    def __repr__(self) -> str:
        klass = type(self)
        return '{}.{}({})'.format(klass.__module__, klass.__name__,
                                  tuple.__repr__(self))

    @property
    def x(self) -> float:
        """The first coordinate."""
        return self[0]

    @property
    def y(self) -> float:
        """The second coordinate."""
        return self[1]
274
275
class Box(typing.Tuple[Point, Point]):
    """Immutable representation of PostgreSQL `box` type.

    Stored as the 2-tuple ``(high, low)``; each corner is converted to a
    :class:`Point`.
    """

    __slots__ = ()

    def __new__(cls, high: typing.Sequence[float],
                low: typing.Sequence[float]) -> 'Box':
        corners = (Point(*high), Point(*low))
        return super().__new__(cls, typing.cast(typing.Any, corners))

    def __repr__(self) -> str:
        klass = type(self)
        return '{}.{}({})'.format(klass.__module__, klass.__name__,
                                  tuple.__repr__(self))

    @property
    def high(self) -> Point:
        """The first corner, as passed to the constructor."""
        return self[0]

    @property
    def low(self) -> Point:
        """The second corner, as passed to the constructor."""
        return self[1]
301
302
class Line(typing.Tuple[float, float, float]):
    """Immutable representation of PostgreSQL `line` type.

    Stored as the 3-tuple ``(A, B, C)`` of line coefficients; the values are
    kept exactly as passed (no float coercion is applied).
    """

    __slots__ = ()

    def __new__(cls, A: float, B: float, C: float) -> 'Line':
        return super().__new__(cls, typing.cast(typing.Any, (A, B, C)))

    def __repr__(self) -> str:
        # Same module-qualified format as Point, Box and LineSegment, instead
        # of the bare tuple repr that hides the type.
        return '{}.{}({})'.format(
            type(self).__module__,
            type(self).__name__,
            tuple.__repr__(self)
        )

    @property
    def A(self) -> float:
        """First coefficient."""
        return self[0]

    @property
    def B(self) -> float:
        """Second coefficient."""
        return self[1]

    @property
    def C(self) -> float:
        """Third coefficient."""
        return self[2]
322
323
class LineSegment(typing.Tuple[Point, Point]):
    """Immutable representation of PostgreSQL `lseg` type.

    Stored as the 2-tuple ``(p1, p2)``; each endpoint is converted to a
    :class:`Point`.
    """

    __slots__ = ()

    def __new__(cls, p1: typing.Sequence[float],
                p2: typing.Sequence[float]) -> 'LineSegment':
        endpoints = (Point(*p1), Point(*p2))
        return super().__new__(cls, typing.cast(typing.Any, endpoints))

    def __repr__(self) -> str:
        klass = type(self)
        return '{}.{}({})'.format(klass.__module__, klass.__name__,
                                  tuple.__repr__(self))

    @property
    def p1(self) -> Point:
        """The first endpoint."""
        return self[0]

    @property
    def p2(self) -> Point:
        """The second endpoint."""
        return self[1]
349
350
class Path:
    """Immutable representation of PostgreSQL `path` type.

    Holds a tuple of :class:`Point` objects in ``points`` plus an
    open/closed flag exposed as :attr:`is_closed`.  Supports ``len()``,
    iteration, indexing/slicing and ``in`` over the points.
    """

    __slots__ = '_is_closed', 'points'

    def __init__(self, *vertices: typing.Sequence[float],
                 is_closed: bool = False) -> None:
        converted = [Point(*coords) for coords in vertices]
        self.points = tuple(converted)
        self._is_closed = is_closed

    @property
    def is_closed(self) -> bool:
        """The closed flag given to the constructor."""
        return self._is_closed

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Path):
            return (self.points == other.points and
                    self._is_closed == other._is_closed)
        return NotImplemented

    def __hash__(self) -> int:
        key = (self.points, self._is_closed)
        return hash(key)

    def __iter__(self) -> typing.Iterator[Point]:
        return iter(self.points)

    def __len__(self) -> int:
        return len(self.points)

    @typing.overload
    def __getitem__(self, i: int) -> Point:
        ...

    @typing.overload
    def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]:
        ...

    def __getitem__(self, i: typing.Union[int, slice]) \
            -> typing.Union[Point, typing.Tuple[Point, ...]]:
        # Delegates to tuple indexing: ints give a Point, slices a tuple.
        pts = self.points
        return pts[i]

    def __contains__(self, point: object) -> bool:
        pts = self.points
        return point in pts
395
396
class Polygon(Path):
    """Immutable representation of PostgreSQL `polygon` type.

    A :class:`Path` whose ``is_closed`` flag is always ``True``; the
    constructor accepts only the vertices.
    """

    __slots__ = ()

    def __init__(self, *vertices: typing.Sequence[float]) -> None:
        # There is no open form of a polygon, so the flag is fixed.
        always_closed = True
        super().__init__(*vertices, is_closed=always_closed)
405
406
class Circle(typing.Tuple[Point, float]):
    """Immutable representation of PostgreSQL `circle` type.

    Stored as the 2-tuple ``(center, radius)``, where ``center`` is always a
    :class:`Point`.
    """

    __slots__ = ()

    def __new__(cls, center: typing.Sequence[float],
                radius: float) -> 'Circle':
        """Create a circle.

        :param center: A :class:`Point` or any 2-item sequence of coordinates.
        :param radius: The radius, stored as given.
        """
        # Normalize the center the same way Box and LineSegment normalize
        # their points, so ``Circle((0, 0), 1).center.x`` works.
        if not isinstance(center, Point):
            center = Point(*center)
        return super().__new__(cls, typing.cast(typing.Any, (center, radius)))

    def __repr__(self) -> str:
        # Same module-qualified format as Point, Box and LineSegment.
        return '{}.{}({})'.format(
            type(self).__module__,
            type(self).__name__,
            tuple.__repr__(self)
        )

    @property
    def center(self) -> Point:
        """The center point."""
        return self[0]

    @property
    def radius(self) -> float:
        """The radius, as passed to the constructor."""
        return self[1]
422