1# Copyright (C) 2016-present the asyncpg authors and contributors
2# <see AUTHORS file>
3#
4# This module is part of asyncpg and is released under
5# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
6
7
8from asyncpg import exceptions
9
10
@cython.final
cdef class PreparedStatementState:
    """Client-side state kept for one prepared statement.

    Stores the statement name and query text, the parameter and row
    descriptions set via ``_set_args_desc()`` / ``_set_row_desc()``, and
    the codecs selected to encode bind arguments and decode result rows.
    The codec tuples and the record descriptor are built lazily by
    ``_ensure_args_encoder()`` / ``_ensure_rows_decoder()``.
    """

    def __cinit__(
        self,
        str name,
        str query,
        BaseProtocol protocol,
        type record_class,
        bint ignore_custom_codec
    ):
        self.name = name
        self.query = query
        # All codec lookups below go through the protocol's settings.
        self.settings = protocol.settings
        # Populated later by _set_*_desc() and _ensure_*().
        self.row_desc = self.parameters_desc = None
        self.args_codecs = self.rows_codecs = None
        self.args_num = self.cols_num = 0
        self.cols_desc = None
        # Set to True by _ensure_*() when any codec is not binary.
        self.have_text_args = self.have_text_cols = False
        self.closed = False
        self.refs = 0
        # Passed to ApgRecord_New() for every decoded row.
        self.record_class = record_class
        # Forwarded as ``ignore_custom_codec`` to get_data_codec() when
        # building the argument/row codecs.
        self.ignore_custom_codec = ignore_custom_codec

    def _get_parameters(self):
        """Return a tuple of ``apg_types.Type`` for the statement parameters.

        Returns an empty tuple when no parameter description is set.

        Raises:
            InternalClientError: if a parameter OID has no known codec.
        """
        cdef Codec codec

        if not self.parameters_desc:
            # Same guard as _get_attributes(): without a description there
            # is nothing to report (iterating None would raise TypeError).
            return ()

        result = []
        for oid in self.parameters_desc:
            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))
            result.append(apg_types.Type(
                oid, codec.name, codec.kind, codec.schema))

        return tuple(result)

    def _get_attributes(self):
        """Return a tuple of ``apg_types.Attribute`` for the result columns.

        Column names are decoded using the connection's encoding.  Returns
        an empty tuple when there is no row description.

        Raises:
            InternalClientError: if a column type OID has no known codec.
        """
        cdef Codec codec

        if not self.row_desc:
            return ()

        result = []
        for d in self.row_desc:
            # Row description tuples: (name, table_oid, column_num,
            # type_oid, type_size, type_mod, format).
            name = d[0]
            oid = d[3]

            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))

            name = name.decode(self.settings._encoding)

            result.append(
                apg_types.Attribute(name,
                    apg_types.Type(oid, codec.name, codec.kind, codec.schema)))

        return tuple(result)

    def _init_types(self):
        """Return the set of type OIDs that lack a usable codec.

        A parameter OID is included when no codec is found or the codec
        cannot encode; a result-column OID is included when no codec is
        found or the codec cannot decode.
        """
        cdef:
            Codec codec
            set missing = set()

        if self.parameters_desc:
            for p_oid in self.parameters_desc:
                codec = self.settings.get_data_codec(<uint32_t>p_oid)
                if codec is None or not codec.has_encoder():
                    missing.add(p_oid)

        if self.row_desc:
            for rdesc in self.row_desc:
                codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
                if codec is None or not codec.has_decoder():
                    missing.add(rdesc[3])

        return missing

    cpdef _init_codecs(self):
        """Build argument encoders and row decoders if not built yet."""
        self._ensure_args_encoder()
        self._ensure_rows_decoder()

    def attach(self):
        """Increment the reference counter of this statement."""
        self.refs += 1

    def detach(self):
        """Decrement the reference counter of this statement."""
        self.refs -= 1

    def mark_closed(self):
        """Flag this statement as closed."""
        self.closed = True

    cdef _encode_bind_msg(self, args, int seqno = -1):
        # Serialize *args* into a new WriteBuffer containing, in order:
        # the argument format codes, the argument count, each argument
        # value (int32 -1 for None), and the result-column format codes.
        #
        # *seqno* is the position of *args* within an executemany()
        # sequence, or -1 otherwise; it is only used in error messages.
        #
        # Raises DataError / InterfaceError for bad user input and
        # InternalClientError for internal inconsistencies.  Assumes
        # _ensure_args_encoder()/_ensure_rows_decoder() have run when
        # codecs are needed.
        cdef:
            int idx
            WriteBuffer writer
            Codec codec

        if not cpython.PySequence_Check(args):
            if seqno >= 0:
                raise exceptions.DataError(
                    f'invalid input in executemany() argument sequence '
                    f'element #{seqno}: expected a sequence, got '
                    f'{type(args).__name__}'
                )
            else:
                # Non executemany() callers do not pass user input directly,
                # so bad input is a bug.
                raise exceptions.InternalClientError(
                    f'Bind: expected a sequence, got {type(args).__name__}')

        # The argument count is sent as an int16.
        if len(args) > 32767:
            raise exceptions.InterfaceError(
                'the number of query arguments cannot exceed 32767')

        writer = WriteBuffer.new()

        num_args_passed = len(args)
        if self.args_num != num_args_passed:
            hint = 'Check the query against the passed list of arguments.'

            if self.args_num == 0:
                # If the server was expecting zero arguments, it is likely
                # that the user tried to parametrize a statement that does
                # not support parameters.
                hint += (r'  Note that parameters are supported only in'
                         r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
                         r' statements, and will *not* work in statements '
                         r' like CREATE VIEW or DECLARE CURSOR.')

            raise exceptions.InterfaceError(
                'the server expects {x} argument{s} for this query, '
                '{y} {w} passed'.format(
                    x=self.args_num, s='s' if self.args_num != 1 else '',
                    y=num_args_passed,
                    w='was' if num_args_passed == 1 else 'were'),
                hint=hint)

        if self.have_text_args:
            # Mixed formats: one format code per argument.
            writer.write_int16(self.args_num)
            for idx in range(self.args_num):
                codec = <Codec>(self.args_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All arguments are in binary format: a single format code
            # (count 1, code 1) written as one int32.
            writer.write_int32(0x00010001)

        writer.write_int16(self.args_num)

        for idx in range(self.args_num):
            arg = args[idx]
            if arg is None:
                writer.write_int32(-1)
            else:
                codec = <Codec>(self.args_codecs[idx])
                try:
                    codec.encode(self.settings, writer, arg)
                except (AssertionError, exceptions.InternalClientError):
                    # These are internal errors and should raise as-is.
                    raise
                except exceptions.InterfaceError as e:
                    # This is already a descriptive error, but annotate
                    # with argument name for clarity.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    raise e.with_msg(
                        f'query argument {pos}: {e.args[0]}'
                    ) from None
                except Exception as e:
                    # Everything else is assumed to be an encoding error
                    # due to invalid input.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    value_repr = repr(arg)
                    if len(value_repr) > 40:
                        value_repr = value_repr[:40] + '...'

                    raise exceptions.DataError(
                        f'invalid input for query argument'
                        f' {pos}: {value_repr} ({e})'
                    ) from e

        if self.have_text_cols:
            # Mixed formats: one format code per result column.
            writer.write_int16(self.cols_num)
            for idx in range(self.cols_num):
                codec = <Codec>(self.rows_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All columns are in binary format
            writer.write_int32(0x00010001)

        return writer

    cdef _ensure_rows_decoder(self):
        # Build self.cols_desc (column name -> index mapping plus the
        # ordered names) and self.rows_codecs from self.row_desc.  Sets
        # have_text_cols if any column codec is not binary.  No-op once
        # cols_desc has been built.
        cdef:
            list cols_names
            object cols_mapping
            tuple row
            uint32_t oid
            Codec codec
            list codecs

        if self.cols_desc is not None:
            return

        if self.cols_num == 0:
            self.cols_desc = record.ApgRecordDesc_New({}, ())
            return

        cols_mapping = collections.OrderedDict()
        cols_names = []
        codecs = []
        for i from 0 <= i < self.cols_num:
            row = self.row_desc[i]
            col_name = row[0].decode(self.settings._encoding)
            cols_mapping[col_name] = i
            cols_names.append(col_name)
            oid = row[3]
            codec = self.settings.get_data_codec(
                oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for OID {}'.format(oid))
            if not codec.is_binary():
                self.have_text_cols = True

            codecs.append(codec)

        self.cols_desc = record.ApgRecordDesc_New(
            cols_mapping, tuple(cols_names))

        self.rows_codecs = tuple(codecs)

    cdef _ensure_args_encoder(self):
        # Build self.args_codecs from self.parameters_desc and set
        # have_text_args if any argument codec is not binary.  No-op when
        # there are no arguments or the codecs were already built.
        cdef:
            uint32_t p_oid
            Codec codec
            list codecs = []

        if self.args_num == 0 or self.args_codecs is not None:
            return

        for i from 0 <= i < self.args_num:
            p_oid = self.parameters_desc[i]
            codec = self.settings.get_data_codec(
                p_oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_encoder():
                raise exceptions.InternalClientError(
                    'no encoder for OID {}'.format(p_oid))
            if not codec.is_binary():
                # At least one argument uses a non-binary format, so
                # _encode_bind_msg() must emit per-argument format codes.
                # (Mirrors the have_text_cols check in
                # _ensure_rows_decoder().)
                self.have_text_args = True

            codecs.append(codec)

        self.args_codecs = tuple(codecs)

    cdef _set_row_desc(self, object desc):
        # Decode and store the row description; cols_num is its length.
        self.row_desc = _decode_row_desc(desc)
        self.cols_num = <int16_t>(len(self.row_desc))

    cdef _set_args_desc(self, object desc):
        # Decode and store the parameter OIDs; args_num is their count.
        self.parameters_desc = _decode_parameters_desc(desc)
        self.args_num = <int16_t>(len(self.parameters_desc))

    cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
        # Decode one data row from cbuf[0:buf_len] into a new record of
        # self.record_class.  Layout read: an int16 field count, then for
        # each field an int32 length (-1 means NULL) followed by that many
        # bytes.  rows_codecs is indexed without checks, so
        # _ensure_rows_decoder() must have run first.
        #
        # Raises ProtocolError if the field count differs from cols_num,
        # and BufferError if a codec leaves bytes unread or the row has
        # trailing data.
        cdef:
            Codec codec
            int16_t fnum
            int32_t flen
            object dec_row
            tuple rows_codecs = self.rows_codecs
            ConnectionSettings settings = self.settings
            int32_t i
            FRBuffer rbuf
            ssize_t bl

        frb_init(&rbuf, cbuf, buf_len)

        fnum = hton.unpack_int16(frb_read(&rbuf, 2))

        if fnum != self.cols_num:
            raise exceptions.ProtocolError(
                'the number of columns in the result row ({}) is '
                'different from what was described ({})'.format(
                    fnum, self.cols_num))

        dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
        for i in range(fnum):
            flen = hton.unpack_int32(frb_read(&rbuf, 4))

            if flen == -1:
                val = None
            else:
                # Clamp buffer size to that of the reported field length
                # to make sure that codecs can rely on read_all() working
                # properly.
                bl = frb_get_len(&rbuf)
                if flen > bl:
                    frb_check(&rbuf, flen)
                frb_set_len(&rbuf, flen)
                codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
                val = codec.decode(settings, &rbuf)
                if frb_get_len(&rbuf) != 0:
                    raise BufferError(
                        'unexpected trailing {} bytes in buffer'.format(
                            frb_get_len(&rbuf)))
                # Restore the remaining length for the following fields.
                frb_set_len(&rbuf, bl - flen)

            # NOTE(review): ApgRecord_SET_ITEM presumably steals a
            # reference (like PyTuple_SET_ITEM), hence the INCREF.
            cpython.Py_INCREF(val)
            record.ApgRecord_SET_ITEM(dec_row, i, val)

        if frb_get_len(&rbuf) != 0:
            raise BufferError('unexpected trailing {} bytes in buffer'.format(
                frb_get_len(&rbuf)))

        return dec_row
336
337
cdef _decode_parameters_desc(object desc):
    # Parse a parameter-description message body: an int16 count followed
    # by that many int32 type OIDs (the layout of PostgreSQL's
    # ParameterDescription message).  Each OID is reinterpreted as an
    # unsigned 32-bit value and collected, in order, into a Python list.
    cdef:
        ReadBuffer buf
        int16_t count
        int pos
        uint32_t type_oid
        list oids

    buf = ReadBuffer.new_message_parser(desc)
    count = buf.read_int16()
    oids = []

    pos = 0
    while pos < count:
        type_oid = <uint32_t>buf.read_int32()
        oids.append(type_oid)
        pos += 1

    return oids
353
354
cdef _decode_row_desc(object desc):
    # Parse a row-description message body: an int16 field count, then for
    # each field a NUL-terminated name followed by table OID (int32),
    # column number (int16), type OID (int32), type size (int16), type
    # modifier (int32) and format code (int16).
    #
    # Returns a list with one 7-tuple per field:
    #   (name: bytes, table_oid, column_num, type_oid,
    #    type_size, type_mod, format_code)
    # with both OIDs reinterpreted as unsigned 32-bit values.
    cdef:
        ReadBuffer buf
        int16_t field_count
        int pos
        bytes col_name
        uint32_t table_oid
        int16_t col_number
        uint32_t type_oid
        int16_t type_size
        int32_t type_mod
        int16_t fmt_code
        list fields

    buf = ReadBuffer.new_message_parser(desc)
    field_count = buf.read_int16()
    fields = []

    pos = 0
    while pos < field_count:
        # Fields must be read in wire order.
        col_name = buf.read_null_str()
        table_oid = <uint32_t>buf.read_int32()
        col_number = buf.read_int16()
        type_oid = <uint32_t>buf.read_int32()
        type_size = buf.read_int16()
        type_mod = buf.read_int32()
        fmt_code = buf.read_int16()

        fields.append((col_name, table_oid, col_number, type_oid,
                       type_size, type_mod, fmt_code))
        pos += 1

    return fields
389