# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from asyncpg import exceptions


@cython.final
cdef class PreparedStatementState:
    """State of a prepared statement as tracked by the protocol.

    Holds the statement name and query text, the parameter and row
    descriptions passed to ``_set_args_desc()`` / ``_set_row_desc()``,
    the codecs used to encode Bind arguments and decode result rows,
    and a reference counter plus a closed flag.
    """

    def __cinit__(
        self,
        str name,
        str query,
        BaseProtocol protocol,
        type record_class,
        bint ignore_custom_codec
    ):
        """Initialize empty statement state.

        :param name: statement name.
        :param query: query text.
        :param protocol: protocol whose ``settings`` are used for all
                         codec lookups.
        :param record_class: class passed to ``ApgRecord_New()`` for each
                             decoded row.
        :param ignore_custom_codec: forwarded to ``get_data_codec()`` when
                                    building argument and row codecs.
        """
        self.name = name
        self.query = query
        self.settings = protocol.settings
        self.row_desc = self.parameters_desc = None
        self.args_codecs = self.rows_codecs = None
        self.args_num = self.cols_num = 0
        self.cols_desc = None
        self.closed = False
        self.refs = 0
        self.record_class = record_class
        self.ignore_custom_codec = ignore_custom_codec

    def _get_parameters(self):
        """Return a tuple of ``apg_types.Type``, one per parameter OID.

        Raises InternalClientError if no codec is known for an OID.
        """
        cdef Codec codec

        result = []
        for oid in self.parameters_desc:
            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))
            result.append(apg_types.Type(
                oid, codec.name, codec.kind, codec.schema))

        return tuple(result)

    def _get_attributes(self):
        """Return a tuple of ``apg_types.Attribute``, one per result column.

        Returns an empty tuple when there is no row description.
        Raises InternalClientError if no codec is known for a column OID.
        """
        cdef Codec codec

        if not self.row_desc:
            return ()

        result = []
        for d in self.row_desc:
            # Row description tuples are built by _decode_row_desc():
            # index 0 is the raw column name, index 3 the type OID.
            name = d[0]
            oid = d[3]

            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))

            name = name.decode(self.settings._encoding)

            result.append(
                apg_types.Attribute(name,
                    apg_types.Type(oid, codec.name, codec.kind, codec.schema)))

        return tuple(result)

    def _init_types(self):
        """Return the set of type OIDs that lack a usable codec.

        A parameter OID is included when it has no codec or its codec has
        no encoder; a result column OID is included when it has no codec
        or its codec has no decoder.
        """
        cdef:
            Codec codec
            set missing = set()

        # NOTE: unlike _ensure_args_encoder() / _ensure_rows_decoder(),
        # these lookups do not pass ignore_custom_codec.
        if self.parameters_desc:
            for p_oid in self.parameters_desc:
                codec = self.settings.get_data_codec(<uint32_t>p_oid)
                if codec is None or not codec.has_encoder():
                    missing.add(p_oid)

        if self.row_desc:
            for rdesc in self.row_desc:
                codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
                if codec is None or not codec.has_decoder():
                    missing.add(rdesc[3])

        return missing

    cpdef _init_codecs(self):
        """Build the argument encoders and the row decoders."""
        self._ensure_args_encoder()
        self._ensure_rows_decoder()

    def attach(self):
        """Increment the reference counter."""
        self.refs += 1

    def detach(self):
        """Decrement the reference counter."""
        self.refs -= 1

    def mark_closed(self):
        """Flag this statement as closed."""
        self.closed = True

    cdef _encode_bind_msg(self, args, int seqno = -1):
        """Encode the format codes and argument values for a Bind message.

        Writes, in order: the parameter format codes, the parameter count,
        each encoded argument (or -1 for None), and the result-column
        format codes.  Returns the filled WriteBuffer.

        :param args: sequence of query arguments.
        :param seqno: index of *args* within an executemany() sequence,
                      or -1 when not called from executemany().

        Raises DataError for bad user input (non-sequence executemany()
        element, unencodable value), InterfaceError for an argument count
        mismatch or more than 32767 arguments, and InternalClientError
        for a non-sequence *args* outside executemany().
        """
        cdef:
            int idx
            WriteBuffer writer
            Codec codec

        if not cpython.PySequence_Check(args):
            if seqno >= 0:
                raise exceptions.DataError(
                    f'invalid input in executemany() argument sequence '
                    f'element #{seqno}: expected a sequence, got '
                    f'{type(args).__name__}'
                )
            else:
                # Non executemany() callers do not pass user input directly,
                # so bad input is a bug.
                raise exceptions.InternalClientError(
                    f'Bind: expected a sequence, got {type(args).__name__}')

        # The parameter count is written as an int16 below.
        if len(args) > 32767:
            raise exceptions.InterfaceError(
                'the number of query arguments cannot exceed 32767')

        writer = WriteBuffer.new()

        num_args_passed = len(args)
        if self.args_num != num_args_passed:
            hint = 'Check the query against the passed list of arguments.'

            if self.args_num == 0:
                # If the server was expecting zero arguments, it is likely
                # that the user tried to parametrize a statement that does
                # not support parameters.
                hint += (r' Note that parameters are supported only in'
                         r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
                         r' statements, and will *not* work in statements '
                         r' like CREATE VIEW or DECLARE CURSOR.')

            raise exceptions.InterfaceError(
                'the server expects {x} argument{s} for this query, '
                '{y} {w} passed'.format(
                    x=self.args_num, s='s' if self.args_num != 1 else '',
                    y=num_args_passed,
                    w='was' if num_args_passed == 1 else 'were'),
                hint=hint)

        if self.have_text_args:
            # Mixed formats: emit one format code per argument.
            writer.write_int16(self.args_num)
            for idx in range(self.args_num):
                codec = <Codec>(self.args_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All arguments are in binary format
            writer.write_int32(0x00010001)

        writer.write_int16(self.args_num)

        for idx in range(self.args_num):
            arg = args[idx]
            if arg is None:
                # A length of -1 encodes NULL.
                writer.write_int32(-1)
            else:
                codec = <Codec>(self.args_codecs[idx])
                try:
                    codec.encode(self.settings, writer, arg)
                except (AssertionError, exceptions.InternalClientError):
                    # These are internal errors and should raise as-is.
                    raise
                except exceptions.InterfaceError as e:
                    # This is already a descriptive error, but annotate
                    # with argument name for clarity.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    raise e.with_msg(
                        f'query argument {pos}: {e.args[0]}'
                    ) from None
                except Exception as e:
                    # Everything else is assumed to be an encoding error
                    # due to invalid input.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    # Keep the error message short for large values.
                    value_repr = repr(arg)
                    if len(value_repr) > 40:
                        value_repr = value_repr[:40] + '...'

                    raise exceptions.DataError(
                        f'invalid input for query argument'
                        f' {pos}: {value_repr} ({e})'
                    ) from e

        if self.have_text_cols:
            # Mixed formats: request one format code per result column.
            writer.write_int16(self.cols_num)
            for idx in range(self.cols_num):
                codec = <Codec>(self.rows_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All columns are in binary format
            writer.write_int32(0x00010001)

        return writer

    cdef _ensure_rows_decoder(self):
        """Build the record descriptor and per-column decoders (once).

        Sets ``cols_desc`` (column name -> index mapping plus names),
        ``rows_codecs`` and ``have_text_cols``.  Raises InternalClientError
        if a column type has no decoder.
        """
        cdef:
            list cols_names
            object cols_mapping
            tuple row
            uint32_t oid
            Codec codec
            list codecs

        if self.cols_desc is not None:
            # Already initialized.
            return

        if self.cols_num == 0:
            self.cols_desc = record.ApgRecordDesc_New({}, ())
            return

        cols_mapping = collections.OrderedDict()
        cols_names = []
        codecs = []
        for i from 0 <= i < self.cols_num:
            row = self.row_desc[i]
            col_name = row[0].decode(self.settings._encoding)
            cols_mapping[col_name] = i
            cols_names.append(col_name)
            oid = row[3]
            codec = self.settings.get_data_codec(
                oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for OID {}'.format(oid))
            if not codec.is_binary():
                self.have_text_cols = True

            codecs.append(codec)

        self.cols_desc = record.ApgRecordDesc_New(
            cols_mapping, tuple(cols_names))

        self.rows_codecs = tuple(codecs)

    cdef _ensure_args_encoder(self):
        """Build the per-argument encoders (once).

        Sets ``args_codecs`` and, when any argument codec is not binary,
        ``have_text_args``.  Raises InternalClientError if a parameter type
        has no encoder.
        """
        cdef:
            uint32_t p_oid
            Codec codec
            list codecs = []

        if self.args_num == 0 or self.args_codecs is not None:
            return

        for i from 0 <= i < self.args_num:
            p_oid = self.parameters_desc[i]
            codec = self.settings.get_data_codec(
                p_oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_encoder():
                raise exceptions.InternalClientError(
                    'no encoder for OID {}'.format(p_oid))
            # Mirror _ensure_rows_decoder(): per-argument format codes are
            # only needed when some argument is not binary; otherwise
            # _encode_bind_msg() can send a single binary format code.
            if not codec.is_binary():
                self.have_text_args = True

            codecs.append(codec)

        self.args_codecs = tuple(codecs)

    cdef _set_row_desc(self, object desc):
        """Parse a row description payload and record the column count."""
        self.row_desc = _decode_row_desc(desc)
        self.cols_num = <int16_t>(len(self.row_desc))

    cdef _set_args_desc(self, object desc):
        """Parse a parameter description payload and record the arg count."""
        self.parameters_desc = _decode_parameters_desc(desc)
        self.args_num = <int16_t>(len(self.parameters_desc))

    cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
        """Decode one data row from *cbuf* into a new record.

        The payload is an int16 field count followed, per field, by an
        int32 length (-1 for NULL) and that many bytes.  Raises
        ProtocolError if the field count differs from ``cols_num`` or a
        field length is negative other than -1, and BufferError if a codec
        leaves unread bytes in its field or the row has trailing bytes.
        """
        cdef:
            Codec codec
            int16_t fnum
            int32_t flen
            object dec_row
            tuple rows_codecs = self.rows_codecs
            ConnectionSettings settings = self.settings
            int32_t i
            FRBuffer rbuf
            ssize_t bl

        frb_init(&rbuf, cbuf, buf_len)

        fnum = hton.unpack_int16(frb_read(&rbuf, 2))

        if fnum != self.cols_num:
            raise exceptions.ProtocolError(
                'the number of columns in the result row ({}) is '
                'different from what was described ({})'.format(
                    fnum, self.cols_num))

        dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
        for i in range(fnum):
            flen = hton.unpack_int32(frb_read(&rbuf, 4))

            if flen == -1:
                val = None
            elif flen < 0:
                # Only -1 (NULL) is a valid negative length.  Any other
                # negative value would otherwise reach frb_set_len() below
                # and then grow the buffer view past the remaining data.
                raise exceptions.ProtocolError(
                    'invalid field length ({}) for column {} in '
                    'the result row'.format(flen, i))
            else:
                # Clamp buffer size to that of the reported field length
                # to make sure that codecs can rely on read_all() working
                # properly.
                bl = frb_get_len(&rbuf)
                if flen > bl:
                    # Presumably raises, since the field is longer than
                    # the remaining buffer.
                    frb_check(&rbuf, flen)
                frb_set_len(&rbuf, flen)
                codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
                val = codec.decode(settings, &rbuf)
                if frb_get_len(&rbuf) != 0:
                    raise BufferError(
                        'unexpected trailing {} bytes in buffer'.format(
                            frb_get_len(&rbuf)))
                # Restore the view to the rest of the row.
                frb_set_len(&rbuf, bl - flen)

            # NOTE(review): presumably ApgRecord_SET_ITEM steals a reference
            # (like PyTuple_SET_ITEM), hence the explicit INCREF.
            cpython.Py_INCREF(val)
            record.ApgRecord_SET_ITEM(dec_row, i, val)

        if frb_get_len(&rbuf) != 0:
            raise BufferError('unexpected trailing {} bytes in buffer'.format(
                frb_get_len(&rbuf)))

        return dec_row


cdef _decode_parameters_desc(object desc):
    """Parse a parameter description payload into a list of type OIDs.

    The payload is an int16 count followed by one int32 OID per parameter.
    """
    cdef:
        ReadBuffer reader
        int16_t nparams
        uint32_t p_oid
        list result = []

    reader = ReadBuffer.new_message_parser(desc)
    nparams = reader.read_int16()

    for i from 0 <= i < nparams:
        p_oid = <uint32_t>reader.read_int32()
        result.append(p_oid)

    return result


cdef _decode_row_desc(object desc):
    """Parse a row description payload into a list of field tuples.

    Each tuple is ``(name_bytes, table_oid, column_num, type_oid,
    type_size, type_modifier, format)``, in the order the fields are read.
    """
    cdef:
        ReadBuffer reader

        int16_t nfields

        bytes f_name
        uint32_t f_table_oid
        int16_t f_column_num
        uint32_t f_dt_oid
        int16_t f_dt_size
        int32_t f_dt_mod
        int16_t f_format

        list result

    reader = ReadBuffer.new_message_parser(desc)
    nfields = reader.read_int16()
    result = []

    for i from 0 <= i < nfields:
        f_name = reader.read_null_str()
        f_table_oid = <uint32_t>reader.read_int32()
        f_column_num = reader.read_int16()
        f_dt_oid = <uint32_t>reader.read_int32()
        f_dt_size = reader.read_int16()
        f_dt_mod = reader.read_int32()
        f_format = reader.read_int16()

        result.append(
            (f_name, f_table_oid, f_column_num, f_dt_oid,
             f_dt_size, f_dt_mod, f_format))

    return result