1# Copyright (C) 2016-present the asyncpg authors and contributors
2# <see AUTHORS file>
3#
4# This module is part of asyncpg and is released under
5# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
6
7
8from asyncpg import types as apg_types
9
10from collections.abc import Sequence as SequenceABC
11
# Bits of the flags byte that starts the binary payload of a range value
# (written by range_encode, read first by range_decode).
# Values mirror postgresql/src/include/utils/rangetypes.h.
DEF RANGE_EMPTY  = 0x01  # range is empty
DEF RANGE_LB_INC = 0x02  # lower bound is inclusive
DEF RANGE_UB_INC = 0x04  # upper bound is inclusive
DEF RANGE_LB_INF = 0x08  # lower bound is -infinity
DEF RANGE_UB_INF = 0x10  # upper bound is +infinity
18
19
# Classification of the Python object handed to range_encode(), as
# computed by _range_type().
cdef enum _RangeArgumentType:
    _RANGE_ARGUMENT_INVALID = 0  # unsupported type; range_encode raises
    _RANGE_ARGUMENT_TUPLE = 1    # a tuple or list (length checked later)
    _RANGE_ARGUMENT_RANGE = 2    # an apg_types.Range instance
24
25
cdef inline bint _range_has_lbound(uint8_t flags):
    # True when a lower-bound value accompanies *flags* on the wire,
    # i.e. the range is neither empty nor unbounded below.
    return (flags & (RANGE_EMPTY | RANGE_LB_INF)) == 0
28
29
cdef inline bint _range_has_ubound(uint8_t flags):
    # True when an upper-bound value accompanies *flags* on the wire,
    # i.e. the range is neither empty nor unbounded above.
    return (flags & (RANGE_EMPTY | RANGE_UB_INF)) == 0
32
33
cdef inline _RangeArgumentType _range_type(object obj):
    # Classify *obj* for range_encode(): plain lists/tuples are matched
    # with the cheap C-level type checks before falling back to an
    # isinstance() test for apg_types.Range.
    cdef _RangeArgumentType kind = _RANGE_ARGUMENT_INVALID

    if cpython.PyList_Check(obj) or cpython.PyTuple_Check(obj):
        kind = _RANGE_ARGUMENT_TUPLE
    elif isinstance(obj, apg_types.Range):
        kind = _RANGE_ARGUMENT_RANGE

    return kind
41
42
cdef range_encode(ConnectionSettings settings, WriteBuffer buf,
                  object obj, uint32_t elem_oid,
                  encode_func_ex encoder, const void *encoder_arg):
    # Encode *obj* into *buf* as a binary PostgreSQL range datum.
    #
    # Accepted forms of *obj*:
    #   * ``()`` / ``[]``       -- the empty range;
    #   * ``(lower,)``          -- inclusive lower bound, unbounded above;
    #   * ``(lower, upper)``    -- both bounds flagged inclusive; a ``None``
    #                              bound is flagged as infinite;
    #   * ``apg_types.Range``   -- emptiness, bounds and inclusivity taken
    #                              from the object; a non-inclusive ``None``
    #                              bound is flagged as infinite.
    #
    # Output: int32 datum length, one flags byte (RANGE_* bits), then each
    # finite bound as written by *encoder*.  *elem_oid* is not used here.
    #
    # Raises TypeError for an unsupported *obj* type, ValueError for a
    # list/tuple with more than two elements, and OverflowError when the
    # encoded datum does not fit the int32 length prefix.
    cdef:
        ssize_t obj_len
        ssize_t bounds_len
        uint8_t flags = 0
        object lower = None
        object upper = None
        WriteBuffer bounds_data = WriteBuffer.new()
        _RangeArgumentType arg_type = _range_type(obj)

    if arg_type == _RANGE_ARGUMENT_INVALID:
        raise TypeError(
            'list, tuple or Range object expected (got type {})'.format(
                type(obj)))

    elif arg_type == _RANGE_ARGUMENT_TUPLE:
        obj_len = len(obj)
        if obj_len == 2:
            lower = obj[0]
            upper = obj[1]

            if lower is None:
                flags |= RANGE_LB_INF

            if upper is None:
                flags |= RANGE_UB_INF

            flags |= RANGE_LB_INC | RANGE_UB_INC

        elif obj_len == 1:
            lower = obj[0]
            flags |= RANGE_LB_INC | RANGE_UB_INF

        elif obj_len == 0:
            flags |= RANGE_EMPTY

        else:
            raise ValueError(
                'expected 0, 1 or 2 elements in range (got {})'.format(
                    obj_len))

    else:
        if obj.isempty:
            flags |= RANGE_EMPTY
        else:
            lower = obj.lower
            upper = obj.upper

            if obj.lower_inc:
                flags |= RANGE_LB_INC
            elif lower is None:
                flags |= RANGE_LB_INF

            if obj.upper_inc:
                flags |= RANGE_UB_INC
            elif upper is None:
                flags |= RANGE_UB_INF

    # Only finite bounds of a non-empty range carry a value on the wire.
    if _range_has_lbound(flags):
        encoder(settings, bounds_data, lower, encoder_arg)

    if _range_has_ubound(flags):
        encoder(settings, bounds_data, upper, encoder_arg)

    # The datum length covers the flags byte plus the encoded bounds and
    # is written as an int32: check explicitly rather than letting the C
    # cast silently wrap (same policy as multirange_encode).
    bounds_len = bounds_data.len()
    if bounds_len > INT32_MAX - 1:
        raise OverflowError(
            f'size of encoded range datum exceeds the maximum allowed'
            f' {INT32_MAX - 1} bytes')

    buf.write_int32(1 + <int32_t>bounds_len)
    buf.write_byte(<int8_t>flags)
    buf.write_buffer(bounds_data)
111
112
cdef inline object _range_decode_bound(ConnectionSettings settings,
                                       FRBuffer *buf,
                                       decode_func_ex decoder,
                                       const void *decoder_arg):
    # Read one int32-length-prefixed range bound from *buf* and decode it
    # with *decoder*.  A length of -1 yields None.  Any other negative
    # length is a protocol violation and is rejected here instead of being
    # handed to frb_slice_from(), whose handling of negative lengths is not
    # visible in this file.
    cdef:
        int32_t bound_len = hton.unpack_int32(frb_read(buf, 4))
        FRBuffer bound_buf

    if bound_len == -1:
        return None

    if bound_len < 0:
        raise exceptions.ProtocolError(
            'unexpected range bound length: {}'.format(bound_len))

    frb_slice_from(&bound_buf, buf, bound_len)
    return decoder(settings, &bound_buf, decoder_arg)


cdef range_decode(ConnectionSettings settings, FRBuffer *buf,
                  decode_func_ex decoder, const void *decoder_arg):
    # Decode a binary PostgreSQL range value from *buf* into an
    # apg_types.Range.  The payload is one flags byte (RANGE_* bits)
    # followed by the lower and then the upper bound, each present only
    # when the flags say that bound is finite and the range non-empty.
    #
    # Raises exceptions.ProtocolError on a malformed bound length.
    cdef:
        uint8_t flags = <uint8_t>frb_read(buf, 1)[0]
        object lower = None
        object upper = None

    if _range_has_lbound(flags):
        lower = _range_decode_bound(settings, buf, decoder, decoder_arg)

    if _range_has_ubound(flags):
        upper = _range_decode_bound(settings, buf, decoder, decoder_arg)

    return apg_types.Range(lower=lower, upper=upper,
                           lower_inc=(flags & RANGE_LB_INC) != 0,
                           upper_inc=(flags & RANGE_UB_INC) != 0,
                           empty=(flags & RANGE_EMPTY) != 0)
142
143
cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf,
                       object obj, uint32_t elem_oid,
                       encode_func_ex encoder, const void *encoder_arg):
    # Encode the sequence *obj* of range values into *buf* as a binary
    # PostgreSQL multirange datum: int32 datum length, int32 element
    # count, then every element as produced by range_encode().
    #
    # Raises TypeError if *obj* is not a collections.abc.Sequence (errors
    # for individual elements come from range_encode()), and OverflowError
    # if the element count or the encoded size does not fit the int32
    # header fields.
    cdef:
        WriteBuffer elem_data
        ssize_t elem_data_len
        ssize_t elem_count

    if not isinstance(obj, SequenceABC):
        raise TypeError(
            'expected a sequence (got type {!r})'.format(type(obj).__name__)
        )

    # Validate the element count before doing any encoding work so an
    # oversized input fails fast.
    elem_count = len(obj)
    if elem_count > INT32_MAX:
        raise OverflowError('too many elements in multirange value')

    elem_data = WriteBuffer.new()
    for elem in obj:
        range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg)

    # The datum length also covers the 4-byte element count field.
    elem_data_len = elem_data.len()
    if elem_data_len > INT32_MAX - 4:
        raise OverflowError(
            f'size of encoded multirange datum exceeds the maximum allowed'
            f' {INT32_MAX - 4} bytes')

    # Datum length
    buf.write_int32(4 + <int32_t>elem_data_len)
    # Number of elements in multirange
    buf.write_int32(<int32_t>elem_count)
    buf.write_buffer(elem_data)
177
178
cdef multirange_decode(ConnectionSettings settings, FRBuffer *buf,
                       decode_func_ex decoder, const void *decoder_arg):
    # Decode a binary PostgreSQL multirange value from *buf* into a list of
    # apg_types.Range objects.  The payload is an int32 element count
    # followed by that many int32-length-prefixed range values, each
    # decoded with range_decode().
    #
    # Raises exceptions.ProtocolError for a negative element count, a NULL
    # element, or any other negative element length.
    cdef:
        int32_t nelems = hton.unpack_int32(frb_read(buf, 4))
        FRBuffer elem_buf
        int32_t elem_len
        int i
        list result

    if nelems == 0:
        return []

    if nelems < 0:
        raise exceptions.ProtocolError(
            'unexpected multirange size value: {}'.format(nelems))

    # Pre-sized list filled in place; PyList_SET_ITEM steals a reference,
    # hence the explicit INCREF below.
    result = cpython.PyList_New(nelems)
    for i in range(nelems):
        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len == -1:
            raise exceptions.ProtocolError(
                'unexpected NULL element in multirange value')
        if elem_len < 0:
            # Reject malformed lengths explicitly instead of passing a
            # negative size to frb_slice_from().
            raise exceptions.ProtocolError(
                'unexpected multirange element length: {}'.format(elem_len))

        frb_slice_from(&elem_buf, buf, elem_len)
        elem = range_decode(settings, &elem_buf, decoder, decoder_arg)
        cpython.Py_INCREF(elem)
        cpython.PyList_SET_ITEM(result, i, elem)

    return result
208