##############################################################################
# Copyright (c) 2009-2018, Hajime Nakagami<nakagami@gmail.com>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Python DB-API 2.0 module for Firebird.
27############################################################################## 28from __future__ import print_function 29import sys 30import os 31import socket 32import datetime 33import decimal 34import select 35import hashlib 36 37try: 38 import crypt 39except ImportError: # Not posix 40 crypt = None 41from firebirdsql.fberrmsgs import messages 42from firebirdsql import ( 43 DisconnectByPeer, 44 InternalError, 45 OperationalError, 46 IntegrityError, 47 DataError 48) 49from firebirdsql.consts import * 50from firebirdsql.utils import * 51from firebirdsql import srp 52from firebirdsql import tz_utils 53from firebirdsql import xdrlib 54try: 55 from Crypto.Cipher import ARC4 56except ImportError: 57 from firebirdsql.arc4 import ARC4 58 59DEBUG = False 60 61 62def DEBUG_OUTPUT(*argv): 63 if not DEBUG: 64 return 65 for s in argv: 66 print(s, end=' ', file=sys.stderr) 67 print(file=sys.stderr) 68 69INFO_SQL_SELECT_DESCRIBE_VARS = bs([ 70 isc_info_sql_select, 71 isc_info_sql_describe_vars, 72 isc_info_sql_sqlda_seq, 73 isc_info_sql_type, 74 isc_info_sql_sub_type, 75 isc_info_sql_scale, 76 isc_info_sql_length, 77 isc_info_sql_null_ind, 78 isc_info_sql_field, 79 isc_info_sql_relation, 80 isc_info_sql_owner, 81 isc_info_sql_alias, 82 isc_info_sql_describe_end]) 83 84 85def get_crypt(plain): 86 if crypt is None: 87 return '' 88 return crypt.crypt(plain, '9z')[2:] 89 90 91def convert_date(v): # Convert datetime.date to BLR format data 92 i = v.month + 9 93 jy = v.year + (i // 12) - 1 94 jm = i % 12 95 c = jy // 100 96 jy -= 100 * c 97 j = (146097*c) // 4 + (1461*jy) // 4 + (153*jm+2) // 5 + v.day - 678882 98 return bint_to_bytes(j, 4) 99 100 101def convert_time(v): # Convert datetime.time to BLR format time 102 t = (v.hour*3600 + v.minute*60 + v.second) * 10000 + v.microsecond // 100 103 return bint_to_bytes(t, 4) 104 105 106def convert_timestamp(v): # Convert datetime.datetime to BLR format timestamp 107 return convert_date(v.date()) + convert_time(v.time()) 108 109 110def 
convert_time_tz(v): # Convert datetime.time to BLR format time_tz 111 try: 112 utc = datetime.timezone.utc 113 except AttributeError: 114 import pytz 115 utc = pytz.utc 116 t = datetime.date.today() 117 native = datetime.datetime( 118 t.year, t.month, t.day, v.hour, v.minute, v.second, v.microsecond 119 ) 120 aware = v.tzinfo.localize(native) 121 v2 = aware.astimezone(utc) 122 123 t = (v2.hour*3600 + v2.minute*60 + v2.second) * 10000 + v2.microsecond // 100 124 r = bint_to_bytes(t, 4) 125 r += bint_to_bytes(tz_utils.get_timezone_id(v.tzinfo.zone), 4) 126 return r 127 128 129def convert_timestamp_tz(v): # Convert datetime.datetime to BLR format timestamp_tz 130 try: 131 utc = datetime.timezone.utc 132 except AttributeError: 133 import pytz 134 utc = pytz.utc 135 native = datetime.datetime( 136 v.year, v.month, v.day, v.hour, v.minute, v.second, v.microsecond 137 ) 138 aware = v.tzinfo.localize(native) 139 v2 = aware.astimezone(utc) 140 141 r = convert_date(v2.date()) + convert_time(v2.time()) 142 r += bint_to_bytes(tz_utils.get_timezone_id(v.tzinfo.zone), 4) 143 return r 144 145 146def wire_operation(fn): 147 if not DEBUG: 148 return fn 149 150 def f(*args, **kwargs): 151 DEBUG_OUTPUT('<--', fn, '-->') 152 r = fn(*args, **kwargs) 153 return r 154 return f 155 156 157class WireProtocol(object): 158 buffer_length = 1024 159 160 op_connect = 1 161 op_exit = 2 162 op_accept = 3 163 op_reject = 4 164 op_protocol = 5 165 op_disconnect = 6 166 op_response = 9 167 op_attach = 19 168 op_create = 20 169 op_detach = 21 170 op_transaction = 29 171 op_commit = 30 172 op_rollback = 31 173 op_open_blob = 35 174 op_get_segment = 36 175 op_put_segment = 37 176 op_close_blob = 39 177 op_info_database = 40 178 op_info_transaction = 42 179 op_batch_segments = 44 180 op_que_events = 48 181 op_cancel_events = 49 182 op_commit_retaining = 50 183 op_event = 52 184 op_connect_request = 53 185 op_aux_connect = 53 186 op_create_blob2 = 57 187 op_allocate_statement = 62 188 op_execute = 63 189 
op_exec_immediate = 64 190 op_fetch = 65 191 op_fetch_response = 66 192 op_free_statement = 67 193 op_prepare_statement = 68 194 op_info_sql = 70 195 op_dummy = 71 196 op_execute2 = 76 197 op_sql_response = 78 198 op_drop_database = 81 199 op_service_attach = 82 200 op_service_detach = 83 201 op_service_info = 84 202 op_service_start = 85 203 op_rollback_retaining = 86 204 # FB3 205 op_update_account_info = 87 206 op_authenticate_user = 88 207 op_partial = 89 208 op_trusted_auth = 90 209 op_cancel = 91 210 op_cont_auth = 92 211 op_ping = 93 212 op_accept_data = 94 213 op_abort_aux_connection = 95 214 op_crypt = 96 215 op_crypt_key_callback = 97 216 op_cond_accept = 98 217 218 def __init__(self): 219 self.accept_plugin_name = '' 220 self.auth_data = b'' 221 222 def recv_channel(self, nbytes, word_alignment=False): 223 n = nbytes 224 if word_alignment and (n % 4): 225 n += 4 - nbytes % 4 # 4 bytes word alignment 226 r = bs([]) 227 while n: 228 if (self.timeout is not None and select.select([self.sock._sock], [], [], self.timeout)[0] == []): 229 break 230 b = self.sock.recv(n) 231 if not b: 232 break 233 r += b 234 n -= len(b) 235 if len(r) < nbytes: 236 raise OperationalError('Can not recv() packets') 237 return r[:nbytes] 238 239 def str_to_bytes(self, s): 240 "convert str to bytes" 241 if ((PYTHON_MAJOR_VER == 3 and isinstance(s,str)) or 242 (PYTHON_MAJOR_VER == 2 and type(s) == unicode)): 243 return s.encode(charset_map.get(self.charset, self.charset)) 244 return s 245 246 def bytes_to_str(self, b): 247 "convert bytes array to raw string" 248 if PYTHON_MAJOR_VER == 3: 249 return b.decode(charset_map.get(self.charset, self.charset)) 250 return b 251 252 def bytes_to_ustr(self, b): 253 "convert bytes array to unicode string" 254 return b.decode(charset_map.get(self.charset, self.charset)) 255 256 def _parse_status_vector(self): 257 sql_code = 0 258 gds_codes = set() 259 message = '' 260 n = bytes_to_bint(self.recv_channel(4)) 261 while n != isc_arg_end: 262 if n == 
isc_arg_gds: 263 gds_code = bytes_to_bint(self.recv_channel(4)) 264 if gds_code: 265 gds_codes.add(gds_code) 266 message += messages.get(gds_code, '@1') 267 num_arg = 0 268 elif n == isc_arg_number: 269 num = bytes_to_bint(self.recv_channel(4)) 270 if gds_code == 335544436: 271 sql_code = num 272 num_arg += 1 273 message = message.replace('@' + str(num_arg), str(num)) 274 elif n == isc_arg_string: 275 nbytes = bytes_to_bint(self.recv_channel(4)) 276 s = self.bytes_to_str(self.recv_channel(nbytes, word_alignment=True)) 277 num_arg += 1 278 message = message.replace('@' + str(num_arg), s) 279 elif n == isc_arg_interpreted: 280 nbytes = bytes_to_bint(self.recv_channel(4)) 281 s = str(self.recv_channel(nbytes, word_alignment=True)) 282 message += s 283 elif n == isc_arg_sql_state: 284 nbytes = bytes_to_bint(self.recv_channel(4)) 285 s = str(self.recv_channel(nbytes, word_alignment=True)) 286 n = bytes_to_bint(self.recv_channel(4)) 287 288 return (gds_codes, sql_code, message) 289 290 def _parse_op_response(self): 291 b = self.recv_channel(16) 292 h = bytes_to_bint(b[0:4]) # Object handle 293 oid = b[4:12] # Object ID 294 buf_len = bytes_to_bint(b[12:]) # buffer length 295 buf = self.recv_channel(buf_len, word_alignment=True) 296 297 (gds_codes, sql_code, message) = self._parse_status_vector() 298 if gds_codes.intersection([ 299 335544838, 335544879, 335544880, 335544466, 335544665, 335544347, 335544558 300 ]): 301 raise IntegrityError(message, gds_codes, sql_code) 302 elif gds_codes.intersection([335544321]): 303 raise DataError(message, gds_codes, sql_code) 304 elif (sql_code or message) and not gds_codes.intersection([335544434]): 305 raise OperationalError(message, gds_codes, sql_code) 306 return (h, oid, buf) 307 308 def _parse_op_event(self): 309 b = self.recv_channel(4096) # too large TODO: read step by step 310 # TODO: parse event name 311 db_handle = bytes_to_bint(b[0:4]) 312 event_id = bytes_to_bint(b[-4:]) 313 314 return (db_handle, event_id, {}) 315 316 def 
_create_blob(self, trans_handle, b): 317 self._op_create_blob2(trans_handle) 318 (blob_handle, blob_id, buf) = self._op_response() 319 320 i = 0 321 while i < len(b): 322 self._op_put_segment(blob_handle, b[i:i+BLOB_SEGMENT_SIZE]) 323 (h, oid, buf) = self._op_response() 324 i += BLOB_SEGMENT_SIZE 325 326 self._op_close_blob(blob_handle) 327 (h, oid, buf) = self._op_response() 328 return blob_id 329 330 def params_to_blr(self, trans_handle, params): 331 "Convert parameter array to BLR and values format." 332 ln = len(params) * 2 333 blr = bs([5, 2, 4, 0, ln & 255, ln >> 8]) 334 if self.accept_version < PROTOCOL_VERSION13: 335 values = bs([]) 336 else: 337 # start with null indicator bitmap 338 null_indicator = 0 339 for i, p in enumerate(params): 340 if p is None: 341 null_indicator |= (1 << i) 342 n = len(params) // 8 343 if len(params) % 8 != 0: 344 n += 1 345 if n % 4: # padding 346 n += 4 - n % 4 347 null_indicator_bytes = [] 348 for i in range(n): 349 null_indicator_bytes.append(null_indicator & 255) 350 null_indicator >>= 8 351 values = bs(null_indicator_bytes) 352 for p in params: 353 if ( 354 (PYTHON_MAJOR_VER == 2 and type(p) == unicode) or 355 (PYTHON_MAJOR_VER == 3 and type(p) == str) 356 ): 357 p = self.str_to_bytes(p) 358 t = type(p) 359 if p is None: 360 v = bs([]) 361 blr += bs([14, 0, 0]) 362 elif ( 363 (PYTHON_MAJOR_VER == 2 and t == str) or 364 (PYTHON_MAJOR_VER == 3 and t == bytes) 365 ): 366 if len(p) > MAX_CHAR_LENGTH: 367 v = self._create_blob(trans_handle, p) 368 blr += bs([9, 0]) 369 else: 370 v = p 371 nbytes = len(v) 372 pad_length = ((4-nbytes) & 3) 373 v += bs([0]) * pad_length 374 blr += bs([14, nbytes & 255, nbytes >> 8]) 375 elif t == int or (PYTHON_MAJOR_VER == 2 and t == long): 376 if p <= 0x7FFFFFFF and p >= -0x80000000: 377 v = bint_to_bytes(p, 4) 378 blr += bs([8, 0]) # blr_long 379 else: 380 v = bint_to_bytes(p, 8) 381 blr += bs([16, 0]) # blr_int64 382 elif t == float and p == float("inf"): 383 v = b'\x7f\x80\x00\x00' 384 blr += 
bs([10]) 385 elif t == decimal.Decimal or t == float: 386 if t == float: 387 p = decimal.Decimal(str(p)) 388 (sign, digits, exponent) = p.as_tuple() 389 v = 0 390 ln = len(digits) 391 for i in range(ln): 392 v += digits[i] * (10 ** (ln - i - 1)) 393 if sign: 394 v *= -1 395 v = bint_to_bytes(v, 8) 396 if exponent < 0: 397 exponent += 256 398 blr += bs([16, exponent]) 399 elif t == datetime.date: 400 v = convert_date(p) 401 blr += bs([12]) 402 elif t == datetime.time: 403 if p.tzinfo: 404 v = convert_time_tz(p) 405 blr += bs([28]) 406 else: 407 v = convert_time(p) 408 blr += bs([13]) 409 elif t == datetime.datetime: 410 if p.tzinfo: 411 v = convert_timestamp_tz(p) 412 blr += bs([29]) 413 else: 414 v = convert_timestamp(p) 415 blr += bs([35]) 416 elif t == bool: 417 v = bs([1, 0, 0, 0]) if p else bs([0, 0, 0, 0]) 418 blr += bs([23]) 419 else: # fallback, convert to string 420 p = p.__repr__() 421 if (PYTHON_MAJOR_VER == 3 and isinstance(p, str)) or (PYTHON_MAJOR_VER == 2 and type(p) == unicode): 422 p = self.str_to_bytes(p) 423 v = p 424 nbytes = len(v) 425 pad_length = ((4-nbytes) & 3) 426 v += bs([0]) * pad_length 427 blr += bs([14, nbytes & 255, nbytes >> 8]) 428 blr += bs([7, 0]) 429 values += v 430 if self.accept_version < PROTOCOL_VERSION13: 431 values += bs([0]) * 4 if not p is None else bs([0xff, 0xff, 0xff, 0xff]) 432 blr += bs([255, 76]) # [blr_end, blr_eoc] 433 return blr, values 434 435 def uid(self, auth_plugin_name, wire_crypt): 436 def pack_cnct_param(k, v): 437 if k != CNCT_specific_data: 438 return bs([k] + [len(v)]) + v 439 # specific_data split per 254 bytes 440 b = b'' 441 i = 0 442 while len(v) > 254: 443 b += bs([k, 255, i]) + v[:254] 444 v = v[254:] 445 i += 1 446 b += bs([k, len(v)+1, i]) + v 447 return b 448 449 auth_plugin_list = ('Srp256', 'Srp', 'Legacy_Auth') 450 # get and calculate CNCT_xxxx values 451 if sys.platform == 'win32': 452 user = os.environ['USERNAME'] 453 hostname = os.environ['COMPUTERNAME'] 454 else: 455 user = 
os.environ.get('USER', '') 456 hostname = socket.gethostname() 457 458 if auth_plugin_name in ('Srp256', 'Srp'): 459 self.client_public_key, self.client_private_key = srp.client_seed() 460 specific_data = bytes_to_hex(srp.long2bytes(self.client_public_key)) 461 elif auth_plugin_name == 'Legacy_Auth': 462 assert crypt, "Legacy_Auth needs crypt module" 463 specific_data = self.str_to_bytes(get_crypt(self.password)) 464 else: 465 raise OperationalError("Unknown auth plugin name '%s'" % (auth_plugin_name,)) 466 self.plugin_name = auth_plugin_name 467 self.plugin_list = b','.join([s.encode('utf-8') for s in auth_plugin_list]) 468 client_crypt = b'\x01\x00\x00\x00' if wire_crypt else b'\x00\x00\x00\x00' 469 470 # set CNCT_xxxx values 471 r = b'' 472 r += pack_cnct_param(CNCT_login, self.str_to_bytes(self.user)) 473 r += pack_cnct_param(CNCT_plugin_name, self.str_to_bytes(self.plugin_name)) 474 r += pack_cnct_param(CNCT_plugin_list, self.plugin_list) 475 r += pack_cnct_param(CNCT_specific_data, specific_data) 476 r += pack_cnct_param(CNCT_client_crypt, client_crypt) 477 478 r += pack_cnct_param(CNCT_user, self.str_to_bytes(user)) 479 r += pack_cnct_param(CNCT_host, self.str_to_bytes(hostname)) 480 r += pack_cnct_param(CNCT_user_verification, b'') 481 return r 482 483 @wire_operation 484 def _op_connect(self, auth_plugin_name, wire_crypt): 485 protocols = [ 486 # PROTOCOL_VERSION, Arch type (Generic=1), min, max, weight 487 '0000000a00000001000000000000000500000002', # 10, 1, 0, 5, 2 488 'ffff800b00000001000000000000000500000004', # 11, 1, 0, 5, 4 489 'ffff800c00000001000000000000000500000006', # 12, 1, 0, 5, 6 490 'ffff800d00000001000000000000000500000008', # 13, 1, 0, 5, 8 491 ] 492 p = xdrlib.Packer() 493 p.pack_int(self.op_connect) 494 p.pack_int(self.op_attach) 495 p.pack_int(3) # CONNECT_VERSION 496 p.pack_int(1) # arch_generic 497 p.pack_bytes(self.str_to_bytes(self.filename if self.filename else '')) 498 499 p.pack_int(len(protocols)) 500 
p.pack_bytes(self.uid(auth_plugin_name, wire_crypt)) 501 self.sock.send(p.get_buffer() + hex_to_bytes(''.join(protocols))) 502 503 @wire_operation 504 def _op_create(self, page_size=4096): 505 dpb = bs([1]) 506 s = self.str_to_bytes(self.charset) 507 dpb += bs([isc_dpb_set_db_charset, len(s)]) + s 508 dpb += bs([isc_dpb_lc_ctype, len(s)]) + s 509 s = self.str_to_bytes(self.user) 510 dpb += bs([isc_dpb_user_name, len(s)]) + s 511 if self.accept_version < PROTOCOL_VERSION13: 512 enc_pass = get_crypt(self.password) 513 if self.accept_version == PROTOCOL_VERSION10 or not enc_pass: 514 s = self.str_to_bytes(self.password) 515 dpb += bs([isc_dpb_password, len(s)]) + s 516 else: 517 enc_pass = self.str_to_bytes(enc_pass) 518 dpb += bs([isc_dpb_password_enc, len(enc_pass)]) + enc_pass 519 if self.role: 520 s = self.str_to_bytes(self.role) 521 dpb += bs([isc_dpb_sql_role_name, len(s)]) + s 522 if self.auth_data: 523 s = bytes_to_hex(self.auth_data) 524 dpb += bs([isc_dpb_specific_auth_data, len(s)]) + s 525 if self.timezone: 526 s = self.str_to_bytes(self.timezone) 527 dpb += bs([isc_dpb_session_time_zone, len(s)]) + s 528 dpb += bs([isc_dpb_sql_dialect, 4]) + int_to_bytes(3, 4) 529 dpb += bs([isc_dpb_force_write, 4]) + int_to_bytes(1, 4) 530 dpb += bs([isc_dpb_overwrite, 4]) + int_to_bytes(1, 4) 531 dpb += bs([isc_dpb_page_size, 4]) + int_to_bytes(page_size, 4) 532 p = xdrlib.Packer() 533 p.pack_int(self.op_create) 534 p.pack_int(0) # Database Object ID 535 p.pack_bytes(self.str_to_bytes(self.filename)) 536 p.pack_bytes(dpb) 537 self.sock.send(p.get_buffer()) 538 539 @wire_operation 540 def _op_cont_auth(self, auth_data, auth_plugin_name, auth_plugin_list, keys): 541 p = xdrlib.Packer() 542 p.pack_int(self.op_cont_auth) 543 p.pack_bytes(bytes_to_hex(auth_data)) 544 p.pack_bytes(auth_plugin_name) 545 p.pack_bytes(auth_plugin_list) 546 p.pack_bytes(keys) 547 self.sock.send(p.get_buffer()) 548 549 @wire_operation 550 def _parse_connect_response(self): 551 # want and treat 
op_accept or op_cond_accept or op_accept_data 552 b = self.recv_channel(4) 553 while bytes_to_bint(b) == self.op_dummy: 554 b = self.recv_channel(4) 555 if bytes_to_bint(b) == self.op_reject: 556 raise OperationalError('Connection is rejected') 557 558 op_code = bytes_to_bint(b) 559 if op_code == self.op_response: 560 return self._parse_op_response() # error occured 561 562 b = self.recv_channel(12) 563 self.accept_version = byte_to_int(b[3]) 564 self.accept_architecture = bytes_to_bint(b[4:8]) 565 self.accept_type = bytes_to_bint(b[8:]) 566 self.lazy_response_count = 0 567 568 if op_code == self.op_cond_accept or op_code == self.op_accept_data: 569 ln = bytes_to_bint(self.recv_channel(4)) 570 data = self.recv_channel(ln, word_alignment=True) 571 572 ln = bytes_to_bint(self.recv_channel(4)) 573 self.accept_plugin_name = self.recv_channel(ln, word_alignment=True) 574 575 is_authenticated = bytes_to_bint(self.recv_channel(4)) 576 ln = bytes_to_bint(self.recv_channel(4)) 577 self.recv_channel(ln, word_alignment=True) # keys 578 579 if is_authenticated == 0: 580 if self.accept_plugin_name in (b'Srp256', b'Srp'): 581 hash_algo = { 582 b'Srp256': hashlib.sha256, 583 b'Srp': hashlib.sha1, 584 }[self.accept_plugin_name] 585 586 user = self.user 587 if len(user) > 2 and user[0] == user[-1] == '"': 588 user = user[1:-1] 589 user = user.replace('""','"') 590 else: 591 user = user.upper() 592 593 if len(data) == 0: 594 # send op_cont_auth 595 self._op_cont_auth( 596 srp.long2bytes(self.client_public_key), 597 self.accept_plugin_name, 598 self.plugin_list, 599 b'' 600 ) 601 # parse op_cont_auth 602 b = self.recv_channel(4) 603 assert bytes_to_bint(b) == self.op_cont_auth 604 ln = bytes_to_bint(self.recv_channel(4)) 605 data = self.recv_channel(ln, word_alignment=True) 606 ln = bytes_to_bint(self.recv_channel(4)) 607 plugin_name = self.recv_channel(ln, word_alignment=True) 608 ln = bytes_to_bint(self.recv_channel(4)) 609 plugin_list = self.recv_channel(ln, word_alignment=True) 
610 ln = bytes_to_bint(self.recv_channel(4)) 611 keys = self.recv_channel(ln, word_alignment=True) 612 613 ln = bytes_to_int(data[:2]) 614 server_salt = data[2:ln+2] 615 server_public_key = srp.bytes2long( 616 hex_to_bytes(data[4+ln:])) 617 618 auth_data, session_key = srp.client_proof( 619 self.str_to_bytes(user), 620 self.str_to_bytes(self.password), 621 server_salt, 622 self.client_public_key, 623 server_public_key, 624 self.client_private_key, 625 hash_algo) 626 elif self.accept_plugin_name == b'Legacy_Auth': 627 auth_data = self.str_to_bytes(get_crypt(self.password)) 628 session_key = b'' 629 else: 630 raise OperationalError( 631 'Unknown auth plugin %s' % (self.accept_plugin_name) 632 ) 633 else: 634 auth_data = b'' 635 session_key = b'' 636 637 if op_code == self.op_cond_accept: 638 self._op_cont_auth( 639 auth_data, 640 self.accept_plugin_name, 641 self.plugin_list, 642 b'' 643 ) 644 (h, oid, buf) = self._op_response() 645 646 if self.wire_crypt and session_key: 647 # op_crypt: plugin[Arc4] key[Symmetric] 648 p = xdrlib.Packer() 649 p.pack_int(self.op_crypt) 650 p.pack_bytes(b'Arc4') 651 p.pack_bytes(b'Symmetric') 652 self.sock.send(p.get_buffer()) 653 self.sock.set_translator( 654 ARC4.new(session_key), ARC4.new(session_key)) 655 (h, oid, buf) = self._op_response() 656 else: # use later _op_attach() and _op_create() 657 self.auth_data = auth_data 658 else: 659 assert op_code == self.op_accept 660 661 @wire_operation 662 def _op_attach(self): 663 dpb = bs([isc_dpb_version1]) 664 s = self.str_to_bytes(self.charset) 665 dpb += bs([isc_dpb_lc_ctype, len(s)]) + s 666 s = self.str_to_bytes(self.user) 667 dpb += bs([isc_dpb_user_name, len(s)]) + s 668 if self.accept_version < PROTOCOL_VERSION13: 669 enc_pass = get_crypt(self.password) 670 if self.accept_version == PROTOCOL_VERSION10 or not enc_pass: 671 s = self.str_to_bytes(self.password) 672 dpb += bs([isc_dpb_password, len(s)]) + s 673 else: 674 enc_pass = self.str_to_bytes(enc_pass) 675 dpb += 
bs([isc_dpb_password_enc, len(enc_pass)]) + enc_pass 676 if self.role: 677 s = self.str_to_bytes(self.role) 678 dpb += bs([isc_dpb_sql_role_name, len(s)]) + s 679 dpb += bs([isc_dpb_process_id, 4]) + int_to_bytes(os.getpid(), 4) 680 s = self.str_to_bytes(sys.argv[0]) 681 dpb += bs([isc_dpb_process_name, len(s)]) + s 682 if self.auth_data: 683 s = bytes_to_hex(self.auth_data) 684 dpb += bs([isc_dpb_specific_auth_data, len(s)]) + s 685 if self.timezone: 686 s = self.str_to_bytes(self.timezone) 687 dpb += bs([isc_dpb_session_time_zone, len(s)]) + s 688 p = xdrlib.Packer() 689 p.pack_int(self.op_attach) 690 p.pack_int(0) # Database Object ID 691 p.pack_bytes(self.str_to_bytes(self.filename)) 692 p.pack_bytes(dpb) 693 self.sock.send(p.get_buffer()) 694 695 @wire_operation 696 def _op_drop_database(self): 697 if self.db_handle is None: 698 raise OperationalError('_op_drop_database() Invalid db handle') 699 p = xdrlib.Packer() 700 p.pack_int(self.op_drop_database) 701 p.pack_int(self.db_handle) 702 self.sock.send(p.get_buffer()) 703 704 @wire_operation 705 def _op_service_attach(self): 706 spb = bs([2, 2]) 707 s = self.str_to_bytes(self.user) 708 spb += bs([isc_spb_user_name, len(s)]) + s 709 if self.accept_version < PROTOCOL_VERSION13: 710 enc_pass = get_crypt(self.password) 711 if self.accept_version == PROTOCOL_VERSION10 or not enc_pass: 712 s = self.str_to_bytes(self.password) 713 spb += bs([isc_dpb_password, len(s)]) + s 714 else: 715 enc_pass = self.str_to_bytes(enc_pass) 716 spb += bs([isc_dpb_password_enc, len(enc_pass)]) + enc_pass 717 if self.auth_data: 718 s = self.str_to_bytes(bytes_to_hex(self.auth_data)) 719 spb += bs([isc_dpb_specific_auth_data, len(s)]) + s 720 spb += bs([isc_spb_dummy_packet_interval, 0x04, 0x78, 0x0a, 0x00, 0x00]) 721 p = xdrlib.Packer() 722 p.pack_int(self.op_service_attach) 723 p.pack_int(0) 724 p.pack_bytes(b'service_mgr') 725 p.pack_bytes(spb) 726 self.sock.send(p.get_buffer()) 727 728 @wire_operation 729 def _op_service_info(self, 
param, item, buffer_length=512): 730 if self.db_handle is None: 731 raise OperationalError('_op_service_info() Invalid db handle') 732 p = xdrlib.Packer() 733 p.pack_int(self.op_service_info) 734 p.pack_int(self.db_handle) 735 p.pack_int(0) 736 p.pack_bytes(param) 737 p.pack_bytes(item) 738 p.pack_int(buffer_length) 739 self.sock.send(p.get_buffer()) 740 741 @wire_operation 742 def _op_service_start(self, param): 743 if self.db_handle is None: 744 raise OperationalError('_op_service_start() Invalid db handle') 745 p = xdrlib.Packer() 746 p.pack_int(self.op_service_start) 747 p.pack_int(self.db_handle) 748 p.pack_int(0) 749 p.pack_bytes(param) 750 self.sock.send(p.get_buffer()) 751 752 @wire_operation 753 def _op_service_detach(self): 754 if self.db_handle is None: 755 raise OperationalError('_op_service_detach() Invalid db handle') 756 p = xdrlib.Packer() 757 p.pack_int(self.op_service_detach) 758 p.pack_int(self.db_handle) 759 self.sock.send(p.get_buffer()) 760 761 @wire_operation 762 def _op_info_database(self, b): 763 if self.db_handle is None: 764 raise OperationalError('_op_info_database() Invalid db handle') 765 p = xdrlib.Packer() 766 p.pack_int(self.op_info_database) 767 p.pack_int(self.db_handle) 768 p.pack_int(0) 769 p.pack_bytes(b) 770 p.pack_int(self.buffer_length) 771 self.sock.send(p.get_buffer()) 772 773 @wire_operation 774 def _op_transaction(self, tpb): 775 if self.db_handle is None: 776 raise OperationalError('_op_transaction() Invalid db handle') 777 p = xdrlib.Packer() 778 p.pack_int(self.op_transaction) 779 p.pack_int(self.db_handle) 780 p.pack_bytes(tpb) 781 self.sock.send(p.get_buffer()) 782 783 @wire_operation 784 def _op_commit(self, trans_handle): 785 p = xdrlib.Packer() 786 p.pack_int(self.op_commit) 787 p.pack_int(trans_handle) 788 self.sock.send(p.get_buffer()) 789 790 @wire_operation 791 def _op_commit_retaining(self, trans_handle): 792 p = xdrlib.Packer() 793 p.pack_int(self.op_commit_retaining) 794 p.pack_int(trans_handle) 795 
self.sock.send(p.get_buffer()) 796 797 @wire_operation 798 def _op_rollback(self, trans_handle): 799 p = xdrlib.Packer() 800 p.pack_int(self.op_rollback) 801 p.pack_int(trans_handle) 802 self.sock.send(p.get_buffer()) 803 804 @wire_operation 805 def _op_rollback_retaining(self, trans_handle): 806 p = xdrlib.Packer() 807 p.pack_int(self.op_rollback_retaining) 808 p.pack_int(trans_handle) 809 self.sock.send(p.get_buffer()) 810 811 @wire_operation 812 def _op_allocate_statement(self): 813 if self.db_handle is None: 814 raise OperationalError('_op_allocate_statement() Invalid db handle') 815 p = xdrlib.Packer() 816 p.pack_int(self.op_allocate_statement) 817 p.pack_int(self.db_handle) 818 self.sock.send(p.get_buffer()) 819 820 @wire_operation 821 def _op_info_transaction(self, trans_handle, b): 822 p = xdrlib.Packer() 823 p.pack_int(self.op_info_transaction) 824 p.pack_int(trans_handle) 825 p.pack_int(0) 826 p.pack_bytes(b) 827 p.pack_int(self.buffer_length) 828 self.sock.send(p.get_buffer()) 829 830 @wire_operation 831 def _op_free_statement(self, stmt_handle, mode): 832 p = xdrlib.Packer() 833 p.pack_int(self.op_free_statement) 834 p.pack_int(stmt_handle) 835 p.pack_int(mode) 836 self.sock.send(p.get_buffer()) 837 838 @wire_operation 839 def _op_prepare_statement(self, stmt_handle, trans_handle, query, option_items=None): 840 if option_items is None: 841 option_items=bs([]) 842 desc_items = option_items + bs([isc_info_sql_stmt_type])+INFO_SQL_SELECT_DESCRIBE_VARS 843 p = xdrlib.Packer() 844 p.pack_int(self.op_prepare_statement) 845 p.pack_int(trans_handle) 846 p.pack_int(stmt_handle) 847 p.pack_int(3) # dialect = 3 848 p.pack_bytes(self.str_to_bytes(query)) 849 p.pack_bytes(desc_items) 850 p.pack_int(self.buffer_length) 851 self.sock.send(p.get_buffer()) 852 853 @wire_operation 854 def _op_info_sql(self, stmt_handle, vars): 855 p = xdrlib.Packer() 856 p.pack_int(self.op_info_sql) 857 p.pack_int(stmt_handle) 858 p.pack_int(0) 859 p.pack_bytes(vars) 860 
p.pack_int(self.buffer_length) 861 self.sock.send(p.get_buffer()) 862 863 @wire_operation 864 def _op_execute(self, stmt_handle, trans_handle, params): 865 p = xdrlib.Packer() 866 p.pack_int(self.op_execute) 867 p.pack_int(stmt_handle) 868 p.pack_int(trans_handle) 869 870 if len(params) == 0: 871 p.pack_bytes(bs([])) 872 p.pack_int(0) 873 p.pack_int(0) 874 self.sock.send(p.get_buffer()) 875 else: 876 (blr, values) = self.params_to_blr(trans_handle, params) 877 p.pack_bytes(blr) 878 p.pack_int(0) 879 p.pack_int(1) 880 self.sock.send(p.get_buffer() + values) 881 882 @wire_operation 883 def _op_execute2(self, stmt_handle, trans_handle, params, output_blr): 884 p = xdrlib.Packer() 885 p.pack_int(self.op_execute2) 886 p.pack_int(stmt_handle) 887 p.pack_int(trans_handle) 888 889 if len(params) == 0: 890 values = b'' 891 p.pack_bytes(bs([])) 892 p.pack_int(0) 893 p.pack_int(0) 894 else: 895 (blr, values) = self.params_to_blr(trans_handle, params) 896 p.pack_bytes(blr) 897 p.pack_int(0) 898 p.pack_int(1) 899 900 q = xdrlib.Packer() 901 q.pack_bytes(output_blr) 902 q.pack_int(0) 903 self.sock.send(p.get_buffer() + values + q.get_buffer()) 904 905 @wire_operation 906 def _op_exec_immediate(self, trans_handle, query): 907 if self.db_handle is None: 908 raise OperationalError('_op_exec_immediate() Invalid db handle') 909 desc_items = bs([]) 910 p = xdrlib.Packer() 911 p.pack_int(self.op_exec_immediate) 912 p.pack_int(trans_handle) 913 p.pack_int(self.db_handle) 914 p.pack_int(3) # dialect = 3 915 p.pack_bytes(self.str_to_bytes(query)) 916 p.pack_bytes(desc_items) 917 p.pack_int(self.buffer_length) 918 self.sock.send(p.get_buffer()) 919 920 @wire_operation 921 def _op_fetch(self, stmt_handle, blr): 922 p = xdrlib.Packer() 923 p.pack_int(self.op_fetch) 924 p.pack_int(stmt_handle) 925 p.pack_bytes(blr) 926 p.pack_int(0) 927 p.pack_int(400) 928 self.sock.send(p.get_buffer()) 929 930 @wire_operation 931 def _op_fetch_response(self, stmt_handle, xsqlda): 932 op_code = 
bytes_to_bint(self.recv_channel(4)) 933 while op_code == self.op_dummy: 934 op_code = bytes_to_bint(self.recv_channel(4)) 935 936 while op_code == self.op_response and self.lazy_response_count: 937 self.lazy_response_count -= 1 938 h, oid, buf = self._parse_op_response() 939 op_code = bytes_to_bint(self.recv_channel(4)) 940 941 if op_code != self.op_fetch_response: 942 if op_code == self.op_response: 943 self._parse_op_response() 944 raise InternalError("op_fetch_response:op_code = %d" % (op_code,)) 945 b = self.recv_channel(8) 946 status = bytes_to_bint(b[:4]) 947 count = bytes_to_bint(b[4:8]) 948 rows = [] 949 while count: 950 r = [None] * len(xsqlda) 951 if self.accept_version < PROTOCOL_VERSION13: 952 for i in range(len(xsqlda)): 953 x = xsqlda[i] 954 if x.io_length() < 0: 955 b = self.recv_channel(4) 956 ln = bytes_to_bint(b) 957 else: 958 ln = x.io_length() 959 raw_value = self.recv_channel(ln, word_alignment=True) 960 if self.recv_channel(4) == bs([0]) * 4: # Not NULL 961 r[i] = x.value(raw_value) 962 else: # PROTOCOL_VERSION13 963 n = len(xsqlda) // 8 964 if len(xsqlda) % 8 != 0: 965 n += 1 966 null_indicator = 0 967 for c in reversed(self.recv_channel(n, word_alignment=True)): 968 null_indicator <<= 8 969 null_indicator += c if PYTHON_MAJOR_VER == 3 else ord(c) 970 for i in range(len(xsqlda)): 971 x = xsqlda[i] 972 if null_indicator & (1 << i): 973 continue 974 if x.io_length() < 0: 975 b = self.recv_channel(4) 976 ln = bytes_to_bint(b) 977 else: 978 ln = x.io_length() 979 raw_value = self.recv_channel(ln, word_alignment=True) 980 r[i] = x.value(raw_value) 981 rows.append(r) 982 b = self.recv_channel(12) 983 op_code = bytes_to_bint(b[:4]) 984 status = bytes_to_bint(b[4:8]) 985 count = bytes_to_bint(b[8:]) 986 return rows, status != 100 987 988 @wire_operation 989 def _op_detach(self): 990 if self.db_handle is None: 991 raise OperationalError('_op_detach() Invalid db handle') 992 p = xdrlib.Packer() 993 p.pack_int(self.op_detach) 994 
p.pack_int(self.db_handle) 995 self.sock.send(p.get_buffer()) 996 997 @wire_operation 998 def _op_open_blob(self, blob_id, trans_handle): 999 p = xdrlib.Packer() 1000 p.pack_int(self.op_open_blob) 1001 p.pack_int(trans_handle) 1002 self.sock.send(p.get_buffer() + blob_id) 1003 1004 @wire_operation 1005 def _op_create_blob2(self, trans_handle): 1006 p = xdrlib.Packer() 1007 p.pack_int(self.op_create_blob2) 1008 p.pack_int(0) 1009 p.pack_int(trans_handle) 1010 p.pack_int(0) 1011 p.pack_int(0) 1012 self.sock.send(p.get_buffer()) 1013 1014 @wire_operation 1015 def _op_get_segment(self, blob_handle): 1016 p = xdrlib.Packer() 1017 p.pack_int(self.op_get_segment) 1018 p.pack_int(blob_handle) 1019 p.pack_int(self.buffer_length) 1020 p.pack_int(0) 1021 self.sock.send(p.get_buffer()) 1022 1023 @wire_operation 1024 def _op_put_segment(self, blob_handle, seg_data): 1025 ln = len(seg_data) 1026 p = xdrlib.Packer() 1027 p.pack_int(self.op_put_segment) 1028 p.pack_int(blob_handle) 1029 p.pack_int(ln) 1030 p.pack_int(ln) 1031 pad_length = (4-ln) & 3 1032 self.sock.send(p.get_buffer() + seg_data + bs([0])*pad_length) 1033 1034 @wire_operation 1035 def _op_batch_segments(self, blob_handle, seg_data): 1036 ln = len(seg_data) 1037 p = xdrlib.Packer() 1038 p.pack_int(self.op_batch_segments) 1039 p.pack_int(blob_handle) 1040 p.pack_int(ln + 2) 1041 p.pack_int(ln + 2) 1042 pad_length = ((4-(ln+2)) & 3) 1043 self.sock.send(p.get_buffer() + int_to_bytes(ln, 2) + seg_data + bs([0])*pad_length) 1044 1045 @wire_operation 1046 def _op_close_blob(self, blob_handle): 1047 p = xdrlib.Packer() 1048 p.pack_int(self.op_close_blob) 1049 p.pack_int(blob_handle) 1050 self.sock.send(p.get_buffer()) 1051 1052 @wire_operation 1053 def _op_que_events(self, event_names, event_id): 1054 if self.db_handle is None: 1055 raise OperationalError('_op_que_events() Invalid db handle') 1056 params = bs([1]) 1057 for name, n in event_names.items(): 1058 params += bs([len(name)]) 1059 params += self.str_to_bytes(name) 
1060 params += int_to_bytes(n, 4) 1061 p = xdrlib.Packer() 1062 p.pack_int(self.op_que_events) 1063 p.pack_int(self.db_handle) 1064 p.pack_bytes(params) 1065 p.pack_int(0) # ast 1066 p.pack_int(0) # args 1067 p.pack_int(event_id) 1068 self.sock.send(p.get_buffer()) 1069 1070 @wire_operation 1071 def _op_cancel_events(self, event_id): 1072 if self.db_handle is None: 1073 raise OperationalError('_op_cancel_events() Invalid db handle') 1074 p = xdrlib.Packer() 1075 p.pack_int(self.op_cancel_events) 1076 p.pack_int(self.db_handle) 1077 p.pack_int(event_id) 1078 self.sock.send(p.get_buffer()) 1079 1080 @wire_operation 1081 def _op_connect_request(self): 1082 if self.db_handle is None: 1083 raise OperationalError('_op_connect_request() Invalid db handle') 1084 p = xdrlib.Packer() 1085 p.pack_int(self.op_connect_request) 1086 p.pack_int(1) # async 1087 p.pack_int(self.db_handle) 1088 p.pack_int(0) 1089 self.sock.send(p.get_buffer()) 1090 1091 @wire_operation 1092 def _op_response(self): 1093 b = self.recv_channel(4) 1094 while bytes_to_bint(b) == self.op_dummy: 1095 b = self.recv_channel(4) 1096 op_code = bytes_to_bint(b) 1097 while op_code == self.op_response and self.lazy_response_count: 1098 self.lazy_response_count -= 1 1099 h, oid, buf = self._parse_op_response() 1100 b = self.recv_channel(4) 1101 if op_code == self.op_cont_auth: 1102 raise OperationalError('Unauthorized') 1103 elif op_code != self.op_response: 1104 raise InternalError("_op_response:op_code = %d" % (op_code,)) 1105 return self._parse_op_response() 1106 1107 @wire_operation 1108 def _op_event(self): 1109 b = self.recv_channel(4) 1110 while bytes_to_bint(b) == self.op_dummy: 1111 b = self.recv_channel(4) 1112 op_code = bytes_to_bint(b) 1113 if op_code == self.op_response and self.lazy_response_count: 1114 self.lazy_response_count -= 1 1115 self._parse_op_response() 1116 b = self.recv_channel(4) 1117 if op_code == self.op_exit or bytes_to_bint(b) == self.op_exit: 1118 raise DisconnectByPeer 1119 if 
op_code != self.op_event: 1120 if op_code == self.op_response: 1121 self._parse_op_response() 1122 raise InternalError("_op_event:op_code = %d" % (op_code,)) 1123 return self._parse_op_event() 1124 1125 @wire_operation 1126 def _op_sql_response(self, xsqlda): 1127 b = self.recv_channel(4) 1128 while bytes_to_bint(b) == self.op_dummy: 1129 b = self.recv_channel(4) 1130 op_code = bytes_to_bint(b) 1131 if op_code != self.op_sql_response: 1132 if op_code == self.op_response: 1133 self._parse_op_response() 1134 raise InternalError("_op_sql_response:op_code = %d" % (op_code,)) 1135 1136 b = self.recv_channel(4) 1137 count = bytes_to_bint(b[:4]) 1138 r = [] 1139 if count == 0: 1140 return [] 1141 if self.accept_version < PROTOCOL_VERSION13: 1142 for i in range(len(xsqlda)): 1143 x = xsqlda[i] 1144 if x.io_length() < 0: 1145 b = self.recv_channel(4) 1146 ln = bytes_to_bint(b) 1147 else: 1148 ln = x.io_length() 1149 raw_value = self.recv_channel(ln, word_alignment=True) 1150 if self.recv_channel(4) == bs([0]) * 4: # Not NULL 1151 r.append(x.value(raw_value)) 1152 else: 1153 r.append(None) 1154 else: 1155 n = len(xsqlda) // 8 1156 if len(xsqlda) % 8 != 0: 1157 n += 1 1158 null_indicator = 0 1159 for c in reversed(self.recv_channel(n, word_alignment=True)): 1160 null_indicator <<= 8 1161 null_indicator += c if PYTHON_MAJOR_VER == 3 else ord(c) 1162 for i in range(len(xsqlda)): 1163 x = xsqlda[i] 1164 if null_indicator & (1 << i): 1165 r.append(None) 1166 else: 1167 if x.io_length() < 0: 1168 b = self.recv_channel(4) 1169 ln = bytes_to_bint(b) 1170 else: 1171 ln = x.io_length() 1172 raw_value = self.recv_channel(ln, word_alignment=True) 1173 r.append(x.value(raw_value)) 1174 return r 1175 1176 def _wait_for_event(self, timeout): 1177 event_names = {} 1178 event_id = 0 1179 while True: 1180 b4 = self.recv_channel(4) 1181 if b4 is None: 1182 return None 1183 op_code = bytes_to_bint(b4) 1184 if op_code == self.op_dummy: 1185 pass 1186 elif op_code == self.op_exit or op_code == 
self.op_disconnect: 1187 break 1188 elif op_code == self.op_event: 1189 bytes_to_int(self.recv_channel(4)) # db_handle 1190 ln = bytes_to_bint(self.recv_channel(4)) 1191 b = self.recv_channel(ln, word_alignment=True) 1192 assert byte_to_int(b[0]) == 1 1193 i = 1 1194 while i < len(b): 1195 ln = byte_to_int(b[i]) 1196 s = self.connection.bytes_to_str(b[i+1:i+1+ln]) 1197 n = bytes_to_int(b[i+1+ln:i+1+ln+4]) 1198 event_names[s] = n 1199 i += ln + 5 1200 self.recv_channel(8) # ignore AST info 1201 1202 event_id = bytes_to_bint(self.recv_channel(4)) 1203 break 1204 else: 1205 raise InternalError("_wait_for_event:op_code = %d" % (op_code,)) 1206 1207 return (event_id, event_names) 1208