1##############################################################################
2# Copyright (c) 2009-2018, Hajime Nakagami<nakagami@gmail.com>
3# All rights reserved.
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are met:
7#
8# * Redistributions of source code must retain the above copyright notice, this
9#   list of conditions and the following disclaimer.
10#
11# * Redistributions in binary form must reproduce the above copyright notice,
12#  this list of conditions and the following disclaimer in the documentation
13#  and/or other materials provided with the distribution.
14#
15# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25#
26# Python DB-API 2.0 module for Firebird.
27##############################################################################
28from __future__ import print_function
29import sys
30import os
31import socket
32import datetime
33import decimal
34import select
35import hashlib
36
37try:
38    import crypt
39except ImportError:     # Not posix
40    crypt = None
41from firebirdsql.fberrmsgs import messages
42from firebirdsql import (
43    DisconnectByPeer,
44    InternalError,
45    OperationalError,
46    IntegrityError,
47    DataError
48)
49from firebirdsql.consts import *
50from firebirdsql.utils import *
51from firebirdsql import srp
52from firebirdsql import tz_utils
53from firebirdsql import xdrlib
54try:
55    from Crypto.Cipher import ARC4
56except ImportError:
57    from firebirdsql.arc4 import ARC4
58
# Debug switch for wire-protocol tracing. DEBUG_OUTPUT reads it on every call;
# wire_operation reads it only when it decorates a function, so it must be set
# before the decorated methods are defined for call tracing to be installed.
DEBUG = False
60
61
def DEBUG_OUTPUT(*argv):
    """Write *argv* to stderr, each item followed by a space, then a newline.

    Does nothing unless the module-level DEBUG flag is true; the flag is
    checked on every call.
    """
    if DEBUG:
        for item in argv:
            print(item, end=' ', file=sys.stderr)
        print(file=sys.stderr)
68
# Info items appended by _op_prepare_statement to its request so the server
# describes the select list of the prepared statement: for each column its
# position, type, sub type, scale, length, nullability, field / relation /
# owner names and alias, terminated by isc_info_sql_describe_end.
INFO_SQL_SELECT_DESCRIBE_VARS = bs([
    isc_info_sql_select,
    isc_info_sql_describe_vars,
    isc_info_sql_sqlda_seq,
    isc_info_sql_type,
    isc_info_sql_sub_type,
    isc_info_sql_scale,
    isc_info_sql_length,
    isc_info_sql_null_ind,
    isc_info_sql_field,
    isc_info_sql_relation,
    isc_info_sql_owner,
    isc_info_sql_alias,
    isc_info_sql_describe_end])
83
84
def get_crypt(plain):
    """Return the Legacy_Auth password hash of *plain*.

    The hash is crypt(3) of *plain* with the fixed salt '9z', with the first
    two characters of the result (the echoed salt) removed. When the crypt
    module could not be imported (non-POSIX platforms) the empty string is
    returned instead.
    """
    if crypt is not None:
        hashed = crypt.crypt(plain, '9z')
        return hashed[2:]
    return ''
89
90
def convert_date(v):
    """Convert a datetime.date to the 4-byte big-endian BLR date value.

    The value is a day count in which 1858-11-17 is day 0, computed with the
    integer-only Julian day formula (the year is shifted to start in March so
    leap days fall at the end of the computational year).
    """
    shifted_month = v.month + 9
    year = v.year + shifted_month // 12 - 1
    shifted_month %= 12
    century, year_in_century = divmod(year, 100)
    day_number = (
        (146097 * century) // 4
        + (1461 * year_in_century) // 4
        + (153 * shifted_month + 2) // 5
        + v.day
        - 678882
    )
    return bint_to_bytes(day_number, 4)
99
100
def convert_time(v):
    """Convert a datetime.time to the 4-byte big-endian BLR time value.

    The value counts units of 1/10000 second (100 microseconds) since
    midnight; sub-100-microsecond precision is truncated.
    """
    seconds_since_midnight = v.hour * 3600 + v.minute * 60 + v.second
    fractions = seconds_since_midnight * 10000 + v.microsecond // 100
    return bint_to_bytes(fractions, 4)
104
105
def convert_timestamp(v):
    """Convert a naive datetime.datetime to the 8-byte BLR timestamp.

    The result is the BLR date of v.date() followed by the BLR time of v.time().
    """
    date_part = convert_date(v.date())
    time_part = convert_time(v.time())
    return date_part + time_part
108
109
def _tz_utc():
    """Return a UTC tzinfo.

    Uses datetime.timezone.utc when it exists (Python 3) and falls back to
    pytz.utc otherwise (Python 2).
    """
    try:
        return datetime.timezone.utc
    except AttributeError:
        import pytz
        return pytz.utc


def _tz_localize(tzinfo, naive):
    """Attach *tzinfo* to the naive datetime *naive*.

    pytz zones provide localize(), which must be used so the offset valid on
    that date is applied. Other tzinfo implementations (for example the
    standard library's zoneinfo.ZoneInfo) have no localize() and are attached
    with replace(tzinfo=...).
    """
    localize = getattr(tzinfo, 'localize', None)
    if localize is not None:
        return localize(naive)
    return naive.replace(tzinfo=tzinfo)


def _tz_name(tzinfo):
    """Return the zone name of *tzinfo*: pytz's ``zone`` or zoneinfo's ``key``.

    Raises AttributeError when *tzinfo* exposes neither attribute.
    """
    name = getattr(tzinfo, 'zone', None)
    if name is None:
        name = tzinfo.key
    return name


def convert_time_tz(v):
    """Convert a tz-aware datetime.time to the 8-byte BLR time_tz value.

    The wall-clock time is interpreted in v.tzinfo on today's date (so the
    offset in effect today is used), converted to UTC, and encoded as a BLR
    time followed by the 4-byte time zone id looked up from the zone name.
    Both pytz zones and zoneinfo.ZoneInfo objects are accepted.
    """
    today = datetime.date.today()
    naive = datetime.datetime(
        today.year, today.month, today.day,
        v.hour, v.minute, v.second, v.microsecond
    )
    v2 = _tz_localize(v.tzinfo, naive).astimezone(_tz_utc())

    # 1/10000-second units since midnight, UTC
    t = (v2.hour*3600 + v2.minute*60 + v2.second) * 10000 + v2.microsecond // 100
    r = bint_to_bytes(t, 4)
    r += bint_to_bytes(tz_utils.get_timezone_id(_tz_name(v.tzinfo)), 4)
    return r


def convert_timestamp_tz(v):
    """Convert a tz-aware datetime.datetime to the 12-byte BLR timestamp_tz value.

    The wall-clock fields of *v* are interpreted in v.tzinfo, converted to UTC
    and encoded as a BLR date + BLR time, followed by the 4-byte time zone id
    looked up from the zone name. Both pytz zones and zoneinfo.ZoneInfo
    objects are accepted.
    """
    naive = datetime.datetime(
        v.year, v.month, v.day, v.hour, v.minute, v.second, v.microsecond
    )
    v2 = _tz_localize(v.tzinfo, naive).astimezone(_tz_utc())

    r = convert_date(v2.date()) + convert_time(v2.time())
    r += bint_to_bytes(tz_utils.get_timezone_id(_tz_name(v.tzinfo)), 4)
    return r
144
145
def wire_operation(fn):
    """Decorator marking a wire-protocol operation.

    DEBUG is checked once, when the decorator is applied. With DEBUG off,
    *fn* is returned unchanged. With DEBUG on, *fn* is wrapped so every call
    is first traced to stderr via DEBUG_OUTPUT. functools.wraps preserves
    the wrapped function's name and docstring, so tracebacks and
    introspection still show the real method rather than the wrapper.
    """
    if not DEBUG:
        return fn

    import functools

    @functools.wraps(fn)
    def f(*args, **kwargs):
        DEBUG_OUTPUT('<--', fn, '-->')
        return fn(*args, **kwargs)
    return f
155
156
157class WireProtocol(object):
    # Reply buffer size sent to the server with the info requests
    # (_op_info_database, _op_info_transaction, _op_prepare_statement).
    buffer_length = 1024

    # Wire protocol operation codes. Each is packed as the leading 4-byte
    # integer of an outgoing packet or compared against the first 4 bytes
    # of an incoming one.
    op_connect = 1
    op_exit = 2
    op_accept = 3
    op_reject = 4
    op_protocol = 5
    op_disconnect = 6
    op_response = 9
    op_attach = 19
    op_create = 20
    op_detach = 21
    op_transaction = 29
    op_commit = 30
    op_rollback = 31
    op_open_blob = 35
    op_get_segment = 36
    op_put_segment = 37
    op_close_blob = 39
    op_info_database = 40
    op_info_transaction = 42
    op_batch_segments = 44
    op_que_events = 48
    op_cancel_events = 49
    op_commit_retaining = 50
    op_event = 52
    op_connect_request = 53
    op_aux_connect = 53     # same value as op_connect_request (alias)
    op_create_blob2 = 57
    op_allocate_statement = 62
    op_execute = 63
    op_exec_immediate = 64
    op_fetch = 65
    op_fetch_response = 66
    op_free_statement = 67
    op_prepare_statement = 68
    op_info_sql = 70
    op_dummy = 71
    op_execute2 = 76
    op_sql_response = 78
    op_drop_database = 81
    op_service_attach = 82
    op_service_detach = 83
    op_service_info = 84
    op_service_start = 85
    op_rollback_retaining = 86
    # FB3
    op_update_account_info = 87
    op_authenticate_user = 88
    op_partial = 89
    op_trusted_auth = 90
    op_cancel = 91
    op_cont_auth = 92
    op_ping = 93
    op_accept_data = 94
    op_abort_aux_connection = 95
    op_crypt = 96
    op_crypt_key_callback = 97
    op_cond_accept = 98
217
218    def __init__(self):
219        self.accept_plugin_name = ''
220        self.auth_data = b''
221
222    def recv_channel(self, nbytes, word_alignment=False):
223        n = nbytes
224        if word_alignment and (n % 4):
225            n += 4 - nbytes % 4  # 4 bytes word alignment
226        r = bs([])
227        while n:
228            if (self.timeout is not None and select.select([self.sock._sock], [], [], self.timeout)[0] == []):
229                break
230            b = self.sock.recv(n)
231            if not b:
232                break
233            r += b
234            n -= len(b)
235        if len(r) < nbytes:
236            raise OperationalError('Can not recv() packets')
237        return r[:nbytes]
238
239    def str_to_bytes(self, s):
240        "convert str to bytes"
241        if ((PYTHON_MAJOR_VER == 3  and isinstance(s,str)) or
242                (PYTHON_MAJOR_VER == 2 and type(s) == unicode)):
243            return s.encode(charset_map.get(self.charset, self.charset))
244        return s
245
246    def bytes_to_str(self, b):
247        "convert bytes array to raw string"
248        if PYTHON_MAJOR_VER == 3:
249            return b.decode(charset_map.get(self.charset, self.charset))
250        return b
251
252    def bytes_to_ustr(self, b):
253        "convert bytes array to unicode string"
254        return b.decode(charset_map.get(self.charset, self.charset))
255
256    def _parse_status_vector(self):
257        sql_code = 0
258        gds_codes = set()
259        message = ''
260        n = bytes_to_bint(self.recv_channel(4))
261        while n != isc_arg_end:
262            if n == isc_arg_gds:
263                gds_code = bytes_to_bint(self.recv_channel(4))
264                if gds_code:
265                    gds_codes.add(gds_code)
266                    message += messages.get(gds_code, '@1')
267                    num_arg = 0
268            elif n == isc_arg_number:
269                num = bytes_to_bint(self.recv_channel(4))
270                if gds_code == 335544436:
271                    sql_code = num
272                num_arg += 1
273                message = message.replace('@' + str(num_arg), str(num))
274            elif n == isc_arg_string:
275                nbytes = bytes_to_bint(self.recv_channel(4))
276                s = self.bytes_to_str(self.recv_channel(nbytes, word_alignment=True))
277                num_arg += 1
278                message = message.replace('@' + str(num_arg), s)
279            elif n == isc_arg_interpreted:
280                nbytes = bytes_to_bint(self.recv_channel(4))
281                s = str(self.recv_channel(nbytes, word_alignment=True))
282                message += s
283            elif n == isc_arg_sql_state:
284                nbytes = bytes_to_bint(self.recv_channel(4))
285                s = str(self.recv_channel(nbytes, word_alignment=True))
286            n = bytes_to_bint(self.recv_channel(4))
287
288        return (gds_codes, sql_code, message)
289
290    def _parse_op_response(self):
291        b = self.recv_channel(16)
292        h = bytes_to_bint(b[0:4])         # Object handle
293        oid = b[4:12]                       # Object ID
294        buf_len = bytes_to_bint(b[12:])   # buffer length
295        buf = self.recv_channel(buf_len, word_alignment=True)
296
297        (gds_codes, sql_code, message) = self._parse_status_vector()
298        if gds_codes.intersection([
299            335544838, 335544879, 335544880, 335544466, 335544665, 335544347, 335544558
300        ]):
301            raise IntegrityError(message, gds_codes, sql_code)
302        elif gds_codes.intersection([335544321]):
303            raise DataError(message, gds_codes, sql_code)
304        elif (sql_code or message) and not gds_codes.intersection([335544434]):
305            raise OperationalError(message, gds_codes, sql_code)
306        return (h, oid, buf)
307
308    def _parse_op_event(self):
309        b = self.recv_channel(4096)     # too large TODO: read step by step
310        # TODO: parse event name
311        db_handle = bytes_to_bint(b[0:4])
312        event_id = bytes_to_bint(b[-4:])
313
314        return (db_handle, event_id, {})
315
316    def _create_blob(self, trans_handle, b):
317        self._op_create_blob2(trans_handle)
318        (blob_handle, blob_id, buf) = self._op_response()
319
320        i = 0
321        while i < len(b):
322            self._op_put_segment(blob_handle, b[i:i+BLOB_SEGMENT_SIZE])
323            (h, oid, buf) = self._op_response()
324            i += BLOB_SEGMENT_SIZE
325
326        self._op_close_blob(blob_handle)
327        (h, oid, buf) = self._op_response()
328        return blob_id
329
330    def params_to_blr(self, trans_handle, params):
331        "Convert parameter array to BLR and values format."
332        ln = len(params) * 2
333        blr = bs([5, 2, 4, 0, ln & 255, ln >> 8])
334        if self.accept_version < PROTOCOL_VERSION13:
335            values = bs([])
336        else:
337            # start with null indicator bitmap
338            null_indicator = 0
339            for i, p in enumerate(params):
340                if p is None:
341                    null_indicator |= (1 << i)
342            n = len(params) // 8
343            if len(params) % 8 != 0:
344                n += 1
345            if n % 4:   # padding
346                n += 4 - n % 4
347            null_indicator_bytes = []
348            for i in range(n):
349                null_indicator_bytes.append(null_indicator & 255)
350                null_indicator >>= 8
351            values = bs(null_indicator_bytes)
352        for p in params:
353            if (
354                (PYTHON_MAJOR_VER == 2 and type(p) == unicode) or
355                (PYTHON_MAJOR_VER == 3 and type(p) == str)
356            ):
357                p = self.str_to_bytes(p)
358            t = type(p)
359            if p is None:
360                v = bs([])
361                blr += bs([14, 0, 0])
362            elif (
363                (PYTHON_MAJOR_VER == 2 and t == str) or
364                (PYTHON_MAJOR_VER == 3 and t == bytes)
365            ):
366                if len(p) > MAX_CHAR_LENGTH:
367                    v = self._create_blob(trans_handle, p)
368                    blr += bs([9, 0])
369                else:
370                    v = p
371                    nbytes = len(v)
372                    pad_length = ((4-nbytes) & 3)
373                    v += bs([0]) * pad_length
374                    blr += bs([14, nbytes & 255, nbytes >> 8])
375            elif t == int or (PYTHON_MAJOR_VER == 2 and t == long):
376                if p <= 0x7FFFFFFF and p >= -0x80000000:
377                    v = bint_to_bytes(p, 4)
378                    blr += bs([8, 0])    # blr_long
379                else:
380                    v = bint_to_bytes(p, 8)
381                    blr += bs([16, 0])    # blr_int64
382            elif t == float and p == float("inf"):
383                v = b'\x7f\x80\x00\x00'
384                blr += bs([10])
385            elif t == decimal.Decimal or t == float:
386                if t == float:
387                    p = decimal.Decimal(str(p))
388                (sign, digits, exponent) = p.as_tuple()
389                v = 0
390                ln = len(digits)
391                for i in range(ln):
392                    v += digits[i] * (10 ** (ln - i - 1))
393                if sign:
394                    v *= -1
395                v = bint_to_bytes(v, 8)
396                if exponent < 0:
397                    exponent += 256
398                blr += bs([16, exponent])
399            elif t == datetime.date:
400                v = convert_date(p)
401                blr += bs([12])
402            elif t == datetime.time:
403                if p.tzinfo:
404                    v = convert_time_tz(p)
405                    blr += bs([28])
406                else:
407                    v = convert_time(p)
408                    blr += bs([13])
409            elif t == datetime.datetime:
410                if p.tzinfo:
411                    v = convert_timestamp_tz(p)
412                    blr += bs([29])
413                else:
414                    v = convert_timestamp(p)
415                    blr += bs([35])
416            elif t == bool:
417                v = bs([1, 0, 0, 0]) if p else bs([0, 0, 0, 0])
418                blr += bs([23])
419            else:   # fallback, convert to string
420                p = p.__repr__()
421                if (PYTHON_MAJOR_VER == 3 and isinstance(p, str)) or (PYTHON_MAJOR_VER == 2 and type(p) == unicode):
422                    p = self.str_to_bytes(p)
423                v = p
424                nbytes = len(v)
425                pad_length = ((4-nbytes) & 3)
426                v += bs([0]) * pad_length
427                blr += bs([14, nbytes & 255, nbytes >> 8])
428            blr += bs([7, 0])
429            values += v
430            if self.accept_version < PROTOCOL_VERSION13:
431                values += bs([0]) * 4 if not p is None else bs([0xff, 0xff, 0xff, 0xff])
432        blr += bs([255, 76])    # [blr_end, blr_eoc]
433        return blr, values
434
435    def uid(self, auth_plugin_name, wire_crypt):
436        def pack_cnct_param(k, v):
437            if k != CNCT_specific_data:
438                return bs([k] + [len(v)]) + v
439            # specific_data split per 254 bytes
440            b = b''
441            i = 0
442            while len(v) > 254:
443                b += bs([k, 255, i]) + v[:254]
444                v = v[254:]
445                i += 1
446            b += bs([k, len(v)+1, i]) + v
447            return b
448
449        auth_plugin_list = ('Srp256', 'Srp', 'Legacy_Auth')
450        # get and calculate CNCT_xxxx values
451        if sys.platform == 'win32':
452            user = os.environ['USERNAME']
453            hostname = os.environ['COMPUTERNAME']
454        else:
455            user = os.environ.get('USER', '')
456            hostname = socket.gethostname()
457
458        if auth_plugin_name in ('Srp256', 'Srp'):
459            self.client_public_key, self.client_private_key = srp.client_seed()
460            specific_data = bytes_to_hex(srp.long2bytes(self.client_public_key))
461        elif auth_plugin_name == 'Legacy_Auth':
462            assert crypt, "Legacy_Auth needs crypt module"
463            specific_data = self.str_to_bytes(get_crypt(self.password))
464        else:
465            raise OperationalError("Unknown auth plugin name '%s'" % (auth_plugin_name,))
466        self.plugin_name = auth_plugin_name
467        self.plugin_list = b','.join([s.encode('utf-8') for s in auth_plugin_list])
468        client_crypt = b'\x01\x00\x00\x00' if wire_crypt else b'\x00\x00\x00\x00'
469
470        # set CNCT_xxxx values
471        r = b''
472        r += pack_cnct_param(CNCT_login, self.str_to_bytes(self.user))
473        r += pack_cnct_param(CNCT_plugin_name, self.str_to_bytes(self.plugin_name))
474        r += pack_cnct_param(CNCT_plugin_list, self.plugin_list)
475        r += pack_cnct_param(CNCT_specific_data, specific_data)
476        r += pack_cnct_param(CNCT_client_crypt, client_crypt)
477
478        r += pack_cnct_param(CNCT_user, self.str_to_bytes(user))
479        r += pack_cnct_param(CNCT_host, self.str_to_bytes(hostname))
480        r += pack_cnct_param(CNCT_user_verification, b'')
481        return r
482
483    @wire_operation
484    def _op_connect(self, auth_plugin_name, wire_crypt):
485        protocols = [
486            # PROTOCOL_VERSION, Arch type (Generic=1), min, max, weight
487            '0000000a00000001000000000000000500000002',     # 10, 1, 0, 5, 2
488            'ffff800b00000001000000000000000500000004',     # 11, 1, 0, 5, 4
489            'ffff800c00000001000000000000000500000006',     # 12, 1, 0, 5, 6
490            'ffff800d00000001000000000000000500000008',     # 13, 1, 0, 5, 8
491        ]
492        p = xdrlib.Packer()
493        p.pack_int(self.op_connect)
494        p.pack_int(self.op_attach)
495        p.pack_int(3)   # CONNECT_VERSION
496        p.pack_int(1)   # arch_generic
497        p.pack_bytes(self.str_to_bytes(self.filename if self.filename else ''))
498
499        p.pack_int(len(protocols))
500        p.pack_bytes(self.uid(auth_plugin_name, wire_crypt))
501        self.sock.send(p.get_buffer() + hex_to_bytes(''.join(protocols)))
502
503    @wire_operation
504    def _op_create(self, page_size=4096):
505        dpb = bs([1])
506        s = self.str_to_bytes(self.charset)
507        dpb += bs([isc_dpb_set_db_charset, len(s)]) + s
508        dpb += bs([isc_dpb_lc_ctype, len(s)]) + s
509        s = self.str_to_bytes(self.user)
510        dpb += bs([isc_dpb_user_name, len(s)]) + s
511        if self.accept_version < PROTOCOL_VERSION13:
512            enc_pass = get_crypt(self.password)
513            if self.accept_version == PROTOCOL_VERSION10 or not enc_pass:
514                s = self.str_to_bytes(self.password)
515                dpb += bs([isc_dpb_password, len(s)]) + s
516            else:
517                enc_pass = self.str_to_bytes(enc_pass)
518                dpb += bs([isc_dpb_password_enc, len(enc_pass)]) + enc_pass
519        if self.role:
520            s = self.str_to_bytes(self.role)
521            dpb += bs([isc_dpb_sql_role_name, len(s)]) + s
522        if self.auth_data:
523            s = bytes_to_hex(self.auth_data)
524            dpb += bs([isc_dpb_specific_auth_data, len(s)]) + s
525        if self.timezone:
526            s = self.str_to_bytes(self.timezone)
527            dpb += bs([isc_dpb_session_time_zone, len(s)]) + s
528        dpb += bs([isc_dpb_sql_dialect, 4]) + int_to_bytes(3, 4)
529        dpb += bs([isc_dpb_force_write, 4]) + int_to_bytes(1, 4)
530        dpb += bs([isc_dpb_overwrite, 4]) + int_to_bytes(1, 4)
531        dpb += bs([isc_dpb_page_size, 4]) + int_to_bytes(page_size, 4)
532        p = xdrlib.Packer()
533        p.pack_int(self.op_create)
534        p.pack_int(0)                       # Database Object ID
535        p.pack_bytes(self.str_to_bytes(self.filename))
536        p.pack_bytes(dpb)
537        self.sock.send(p.get_buffer())
538
539    @wire_operation
540    def _op_cont_auth(self, auth_data, auth_plugin_name, auth_plugin_list, keys):
541        p = xdrlib.Packer()
542        p.pack_int(self.op_cont_auth)
543        p.pack_bytes(bytes_to_hex(auth_data))
544        p.pack_bytes(auth_plugin_name)
545        p.pack_bytes(auth_plugin_list)
546        p.pack_bytes(keys)
547        self.sock.send(p.get_buffer())
548
    @wire_operation
    def _parse_connect_response(self):
        """Read and process the server's reply to op_connect.

        op_dummy packets are skipped and op_reject raises OperationalError.
        op_response is handed to _parse_op_response (which raises on error)
        and its result is returned. Otherwise the accepted protocol version,
        architecture and type are stored. For op_cond_accept / op_accept_data
        the plugin data is read and, if the server has not already
        authenticated the client, auth data is computed: an SRP client proof
        for Srp256/Srp, or the crypt hash for Legacy_Auth. For op_cond_accept
        that auth data is sent with op_cont_auth. If self.wire_crypt is set
        and a session key exists, op_crypt is sent and the socket switches to
        ARC4 in both directions; otherwise the auth data is kept in
        self.auth_data for _op_attach() / _op_create().
        """
        # want and treat op_accept or op_cond_accept or op_accept_data
        b = self.recv_channel(4)
        while bytes_to_bint(b) == self.op_dummy:
            b = self.recv_channel(4)
        if bytes_to_bint(b) == self.op_reject:
            raise OperationalError('Connection is rejected')

        op_code = bytes_to_bint(b)
        if op_code == self.op_response:
            return self._parse_op_response()    # error occurred

        # accepted protocol version (only its low byte is kept), architecture, type
        b = self.recv_channel(12)
        self.accept_version = byte_to_int(b[3])
        self.accept_architecture = bytes_to_bint(b[4:8])
        self.accept_type = bytes_to_bint(b[8:])
        self.lazy_response_count = 0

        if op_code == self.op_cond_accept or op_code == self.op_accept_data:
            # plugin-specific data from the server (SRP: salt + public key)
            ln = bytes_to_bint(self.recv_channel(4))
            data = self.recv_channel(ln, word_alignment=True)

            ln = bytes_to_bint(self.recv_channel(4))
            self.accept_plugin_name = self.recv_channel(ln, word_alignment=True)

            is_authenticated = bytes_to_bint(self.recv_channel(4))
            ln = bytes_to_bint(self.recv_channel(4))
            self.recv_channel(ln, word_alignment=True)   # keys

            if is_authenticated == 0:
                if self.accept_plugin_name in (b'Srp256',  b'Srp'):
                    hash_algo = {
                        b'Srp256': hashlib.sha256,
                        b'Srp': hashlib.sha1,
                    }[self.accept_plugin_name]

                    # A double-quoted user name is used verbatim (quotes
                    # stripped, "" unescaped); otherwise it is upper-cased.
                    user = self.user
                    if len(user) > 2 and user[0] == user[-1] == '"':
                        user = user[1:-1]
                        user = user.replace('""','"')
                    else:
                        user = user.upper()

                    if len(data) == 0:
                        # Server sent no SRP data yet: send our public key
                        # with op_cont_auth and read its reply.
                        # send op_cont_auth
                        self._op_cont_auth(
                            srp.long2bytes(self.client_public_key),
                            self.accept_plugin_name,
                            self.plugin_list,
                            b''
                        )
                        # parse op_cont_auth
                        # NOTE(review): an op_response (error) here fails this
                        # assert instead of raising OperationalError.
                        b = self.recv_channel(4)
                        assert bytes_to_bint(b) == self.op_cont_auth
                        ln = bytes_to_bint(self.recv_channel(4))
                        data = self.recv_channel(ln, word_alignment=True)
                        ln = bytes_to_bint(self.recv_channel(4))
                        plugin_name = self.recv_channel(ln, word_alignment=True)
                        ln = bytes_to_bint(self.recv_channel(4))
                        plugin_list = self.recv_channel(ln, word_alignment=True)
                        ln = bytes_to_bint(self.recv_channel(4))
                        keys = self.recv_channel(ln, word_alignment=True)

                    # data layout: 2-byte salt length, salt, 2 bytes skipped,
                    # then the hex-encoded server public key
                    ln = bytes_to_int(data[:2])
                    server_salt = data[2:ln+2]
                    server_public_key = srp.bytes2long(
                        hex_to_bytes(data[4+ln:]))

                    auth_data, session_key = srp.client_proof(
                        self.str_to_bytes(user),
                        self.str_to_bytes(self.password),
                        server_salt,
                        self.client_public_key,
                        server_public_key,
                        self.client_private_key,
                        hash_algo)
                elif self.accept_plugin_name == b'Legacy_Auth':
                    auth_data = self.str_to_bytes(get_crypt(self.password))
                    session_key = b''
                else:
                    raise OperationalError(
                        'Unknown auth plugin %s' % (self.accept_plugin_name)
                    )
            else:
                auth_data = b''
                session_key = b''

            if op_code == self.op_cond_accept:
                self._op_cont_auth(
                    auth_data,
                    self.accept_plugin_name,
                    self.plugin_list,
                    b''
                )
                (h, oid, buf) = self._op_response()

            if self.wire_crypt and session_key:
                # op_crypt: plugin[Arc4] key[Symmetric]
                p = xdrlib.Packer()
                p.pack_int(self.op_crypt)
                p.pack_bytes(b'Arc4')
                p.pack_bytes(b'Symmetric')
                self.sock.send(p.get_buffer())
                # Both directions are keyed with the SRP session key.
                self.sock.set_translator(
                    ARC4.new(session_key), ARC4.new(session_key))
                (h, oid, buf) = self._op_response()
            else:   # use later _op_attach() and _op_create()
                self.auth_data = auth_data
        else:
            assert op_code == self.op_accept
660
661    @wire_operation
662    def _op_attach(self):
663        dpb = bs([isc_dpb_version1])
664        s = self.str_to_bytes(self.charset)
665        dpb += bs([isc_dpb_lc_ctype, len(s)]) + s
666        s = self.str_to_bytes(self.user)
667        dpb += bs([isc_dpb_user_name, len(s)]) + s
668        if self.accept_version < PROTOCOL_VERSION13:
669            enc_pass = get_crypt(self.password)
670            if self.accept_version == PROTOCOL_VERSION10 or not enc_pass:
671                s = self.str_to_bytes(self.password)
672                dpb += bs([isc_dpb_password, len(s)]) + s
673            else:
674                enc_pass = self.str_to_bytes(enc_pass)
675                dpb += bs([isc_dpb_password_enc, len(enc_pass)]) + enc_pass
676        if self.role:
677            s = self.str_to_bytes(self.role)
678            dpb += bs([isc_dpb_sql_role_name, len(s)]) + s
679        dpb += bs([isc_dpb_process_id, 4]) + int_to_bytes(os.getpid(), 4)
680        s = self.str_to_bytes(sys.argv[0])
681        dpb += bs([isc_dpb_process_name, len(s)]) + s
682        if self.auth_data:
683            s = bytes_to_hex(self.auth_data)
684            dpb += bs([isc_dpb_specific_auth_data, len(s)]) + s
685        if self.timezone:
686            s = self.str_to_bytes(self.timezone)
687            dpb += bs([isc_dpb_session_time_zone, len(s)]) + s
688        p = xdrlib.Packer()
689        p.pack_int(self.op_attach)
690        p.pack_int(0)                       # Database Object ID
691        p.pack_bytes(self.str_to_bytes(self.filename))
692        p.pack_bytes(dpb)
693        self.sock.send(p.get_buffer())
694
695    @wire_operation
696    def _op_drop_database(self):
697        if self.db_handle is None:
698            raise OperationalError('_op_drop_database() Invalid db handle')
699        p = xdrlib.Packer()
700        p.pack_int(self.op_drop_database)
701        p.pack_int(self.db_handle)
702        self.sock.send(p.get_buffer())
703
704    @wire_operation
705    def _op_service_attach(self):
706        spb = bs([2, 2])
707        s = self.str_to_bytes(self.user)
708        spb += bs([isc_spb_user_name, len(s)]) + s
709        if self.accept_version < PROTOCOL_VERSION13:
710            enc_pass = get_crypt(self.password)
711            if self.accept_version == PROTOCOL_VERSION10 or not enc_pass:
712                s = self.str_to_bytes(self.password)
713                spb += bs([isc_dpb_password, len(s)]) + s
714            else:
715                enc_pass = self.str_to_bytes(enc_pass)
716                spb += bs([isc_dpb_password_enc, len(enc_pass)]) + enc_pass
717        if self.auth_data:
718            s = self.str_to_bytes(bytes_to_hex(self.auth_data))
719            spb += bs([isc_dpb_specific_auth_data, len(s)]) + s
720        spb += bs([isc_spb_dummy_packet_interval, 0x04, 0x78, 0x0a, 0x00, 0x00])
721        p = xdrlib.Packer()
722        p.pack_int(self.op_service_attach)
723        p.pack_int(0)
724        p.pack_bytes(b'service_mgr')
725        p.pack_bytes(spb)
726        self.sock.send(p.get_buffer())
727
728    @wire_operation
729    def _op_service_info(self, param, item, buffer_length=512):
730        if self.db_handle is None:
731            raise OperationalError('_op_service_info() Invalid db handle')
732        p = xdrlib.Packer()
733        p.pack_int(self.op_service_info)
734        p.pack_int(self.db_handle)
735        p.pack_int(0)
736        p.pack_bytes(param)
737        p.pack_bytes(item)
738        p.pack_int(buffer_length)
739        self.sock.send(p.get_buffer())
740
741    @wire_operation
742    def _op_service_start(self, param):
743        if self.db_handle is None:
744            raise OperationalError('_op_service_start() Invalid db handle')
745        p = xdrlib.Packer()
746        p.pack_int(self.op_service_start)
747        p.pack_int(self.db_handle)
748        p.pack_int(0)
749        p.pack_bytes(param)
750        self.sock.send(p.get_buffer())
751
752    @wire_operation
753    def _op_service_detach(self):
754        if self.db_handle is None:
755            raise OperationalError('_op_service_detach() Invalid db handle')
756        p = xdrlib.Packer()
757        p.pack_int(self.op_service_detach)
758        p.pack_int(self.db_handle)
759        self.sock.send(p.get_buffer())
760
761    @wire_operation
762    def _op_info_database(self, b):
763        if self.db_handle is None:
764            raise OperationalError('_op_info_database() Invalid db handle')
765        p = xdrlib.Packer()
766        p.pack_int(self.op_info_database)
767        p.pack_int(self.db_handle)
768        p.pack_int(0)
769        p.pack_bytes(b)
770        p.pack_int(self.buffer_length)
771        self.sock.send(p.get_buffer())
772
773    @wire_operation
774    def _op_transaction(self, tpb):
775        if self.db_handle is None:
776            raise OperationalError('_op_transaction() Invalid db handle')
777        p = xdrlib.Packer()
778        p.pack_int(self.op_transaction)
779        p.pack_int(self.db_handle)
780        p.pack_bytes(tpb)
781        self.sock.send(p.get_buffer())
782
783    @wire_operation
784    def _op_commit(self, trans_handle):
785        p = xdrlib.Packer()
786        p.pack_int(self.op_commit)
787        p.pack_int(trans_handle)
788        self.sock.send(p.get_buffer())
789
790    @wire_operation
791    def _op_commit_retaining(self, trans_handle):
792        p = xdrlib.Packer()
793        p.pack_int(self.op_commit_retaining)
794        p.pack_int(trans_handle)
795        self.sock.send(p.get_buffer())
796
797    @wire_operation
798    def _op_rollback(self, trans_handle):
799        p = xdrlib.Packer()
800        p.pack_int(self.op_rollback)
801        p.pack_int(trans_handle)
802        self.sock.send(p.get_buffer())
803
804    @wire_operation
805    def _op_rollback_retaining(self, trans_handle):
806        p = xdrlib.Packer()
807        p.pack_int(self.op_rollback_retaining)
808        p.pack_int(trans_handle)
809        self.sock.send(p.get_buffer())
810
811    @wire_operation
812    def _op_allocate_statement(self):
813        if self.db_handle is None:
814            raise OperationalError('_op_allocate_statement() Invalid db handle')
815        p = xdrlib.Packer()
816        p.pack_int(self.op_allocate_statement)
817        p.pack_int(self.db_handle)
818        self.sock.send(p.get_buffer())
819
820    @wire_operation
821    def _op_info_transaction(self, trans_handle, b):
822        p = xdrlib.Packer()
823        p.pack_int(self.op_info_transaction)
824        p.pack_int(trans_handle)
825        p.pack_int(0)
826        p.pack_bytes(b)
827        p.pack_int(self.buffer_length)
828        self.sock.send(p.get_buffer())
829
830    @wire_operation
831    def _op_free_statement(self, stmt_handle, mode):
832        p = xdrlib.Packer()
833        p.pack_int(self.op_free_statement)
834        p.pack_int(stmt_handle)
835        p.pack_int(mode)
836        self.sock.send(p.get_buffer())
837
838    @wire_operation
839    def _op_prepare_statement(self, stmt_handle, trans_handle, query, option_items=None):
840        if option_items is None:
841            option_items=bs([])
842        desc_items = option_items + bs([isc_info_sql_stmt_type])+INFO_SQL_SELECT_DESCRIBE_VARS
843        p = xdrlib.Packer()
844        p.pack_int(self.op_prepare_statement)
845        p.pack_int(trans_handle)
846        p.pack_int(stmt_handle)
847        p.pack_int(3)   # dialect = 3
848        p.pack_bytes(self.str_to_bytes(query))
849        p.pack_bytes(desc_items)
850        p.pack_int(self.buffer_length)
851        self.sock.send(p.get_buffer())
852
853    @wire_operation
854    def _op_info_sql(self, stmt_handle, vars):
855        p = xdrlib.Packer()
856        p.pack_int(self.op_info_sql)
857        p.pack_int(stmt_handle)
858        p.pack_int(0)
859        p.pack_bytes(vars)
860        p.pack_int(self.buffer_length)
861        self.sock.send(p.get_buffer())
862
863    @wire_operation
864    def _op_execute(self, stmt_handle, trans_handle, params):
865        p = xdrlib.Packer()
866        p.pack_int(self.op_execute)
867        p.pack_int(stmt_handle)
868        p.pack_int(trans_handle)
869
870        if len(params) == 0:
871            p.pack_bytes(bs([]))
872            p.pack_int(0)
873            p.pack_int(0)
874            self.sock.send(p.get_buffer())
875        else:
876            (blr, values) = self.params_to_blr(trans_handle, params)
877            p.pack_bytes(blr)
878            p.pack_int(0)
879            p.pack_int(1)
880            self.sock.send(p.get_buffer() + values)
881
882    @wire_operation
883    def _op_execute2(self, stmt_handle, trans_handle, params, output_blr):
884        p = xdrlib.Packer()
885        p.pack_int(self.op_execute2)
886        p.pack_int(stmt_handle)
887        p.pack_int(trans_handle)
888
889        if len(params) == 0:
890            values = b''
891            p.pack_bytes(bs([]))
892            p.pack_int(0)
893            p.pack_int(0)
894        else:
895            (blr, values) = self.params_to_blr(trans_handle, params)
896            p.pack_bytes(blr)
897            p.pack_int(0)
898            p.pack_int(1)
899
900        q = xdrlib.Packer()
901        q.pack_bytes(output_blr)
902        q.pack_int(0)
903        self.sock.send(p.get_buffer() + values + q.get_buffer())
904
905    @wire_operation
906    def _op_exec_immediate(self, trans_handle, query):
907        if self.db_handle is None:
908            raise OperationalError('_op_exec_immediate() Invalid db handle')
909        desc_items = bs([])
910        p = xdrlib.Packer()
911        p.pack_int(self.op_exec_immediate)
912        p.pack_int(trans_handle)
913        p.pack_int(self.db_handle)
914        p.pack_int(3)   # dialect = 3
915        p.pack_bytes(self.str_to_bytes(query))
916        p.pack_bytes(desc_items)
917        p.pack_int(self.buffer_length)
918        self.sock.send(p.get_buffer())
919
920    @wire_operation
921    def _op_fetch(self, stmt_handle, blr):
922        p = xdrlib.Packer()
923        p.pack_int(self.op_fetch)
924        p.pack_int(stmt_handle)
925        p.pack_bytes(blr)
926        p.pack_int(0)
927        p.pack_int(400)
928        self.sock.send(p.get_buffer())
929
930    @wire_operation
    def _op_fetch_response(self, stmt_handle, xsqlda):
        """Receive the reply to an op_fetch and decode the returned rows.

        Leading op_dummy packets are skipped. Pending op_response packets are
        parsed and discarded while self.lazy_response_count is non-zero,
        decrementing it each time. Row messages are then read until a message
        header reports a count of 0.

        Args:
            stmt_handle: not used by this body.
            xsqlda: sequence of column descriptors. Each provides
                io_length() (negative means a 4-byte length prefix precedes
                the value) and value(raw_bytes) to decode one column.

        Returns:
            (rows, more): rows is a list of per-row lists with one entry per
            column (None for NULL). more is False when the final status is
            100 and True otherwise.

        Raises:
            InternalError: the reply is not an op_fetch_response. An
                op_response reply is first passed to _parse_op_response(),
                presumably so a server-side error is raised from there.
        """
        op_code = bytes_to_bint(self.recv_channel(4))
        while op_code == self.op_dummy:
            op_code = bytes_to_bint(self.recv_channel(4))

        # Discard replies owed to earlier lazily-sent requests (presumed from
        # the counter's name) before the fetch reply itself.
        while op_code == self.op_response and self.lazy_response_count:
            self.lazy_response_count -= 1
            h, oid, buf = self._parse_op_response()
            op_code = bytes_to_bint(self.recv_channel(4))

        if op_code != self.op_fetch_response:
            if op_code == self.op_response:
                self._parse_op_response()
            raise InternalError("op_fetch_response:op_code = %d" % (op_code,))
        # First message header after the op code: 4-byte status, 4-byte count.
        b = self.recv_channel(8)
        status = bytes_to_bint(b[:4])
        count = bytes_to_bint(b[4:8])
        rows = []
        # One row per message; the loop ends when a header reports count 0.
        while count:
            r = [None] * len(xsqlda)
            if self.accept_version < PROTOCOL_VERSION13:
                # Pre-v13 layout: every column sends its word-aligned value
                # followed by a 4-byte NULL flag (all zero bytes = NOT NULL).
                for i in range(len(xsqlda)):
                    x = xsqlda[i]
                    # Negative io_length(): the length arrives as a 4-byte prefix.
                    if x.io_length() < 0:
                        b = self.recv_channel(4)
                        ln = bytes_to_bint(b)
                    else:
                        ln = x.io_length()
                    raw_value = self.recv_channel(ln, word_alignment=True)
                    if self.recv_channel(4) == bs([0]) * 4:     # Not NULL
                        r[i] = x.value(raw_value)
            else:   # PROTOCOL_VERSION13
                # v13 layout: a NULL bitmap of ceil(len(xsqlda) / 8) bytes
                # comes first. It is assembled so byte 0 holds the lowest
                # bits, making bit i correspond to column i. A set bit means
                # NULL, and no value bytes are sent for that column.
                n = len(xsqlda) // 8
                if len(xsqlda) % 8 != 0:
                    n += 1
                null_indicator = 0
                for c in reversed(self.recv_channel(n, word_alignment=True)):
                    null_indicator <<= 8
                    null_indicator += c if PYTHON_MAJOR_VER == 3 else ord(c)
                for i in range(len(xsqlda)):
                    x = xsqlda[i]
                    if null_indicator & (1 << i):
                        continue
                    if x.io_length() < 0:
                        b = self.recv_channel(4)
                        ln = bytes_to_bint(b)
                    else:
                        ln = x.io_length()
                    raw_value = self.recv_channel(ln, word_alignment=True)
                    r[i] = x.value(raw_value)
            rows.append(r)
            # Next message header: op code (read but not checked), status, count.
            b = self.recv_channel(12)
            op_code = bytes_to_bint(b[:4])
            status = bytes_to_bint(b[4:8])
            count = bytes_to_bint(b[8:])
        # Status 100 presumably means end of cursor (SQLCODE 100). The flag
        # tells the caller whether another fetch may return more rows.
        return rows, status != 100
987
988    @wire_operation
989    def _op_detach(self):
990        if self.db_handle is None:
991            raise OperationalError('_op_detach() Invalid db handle')
992        p = xdrlib.Packer()
993        p.pack_int(self.op_detach)
994        p.pack_int(self.db_handle)
995        self.sock.send(p.get_buffer())
996
997    @wire_operation
998    def _op_open_blob(self, blob_id, trans_handle):
999        p = xdrlib.Packer()
1000        p.pack_int(self.op_open_blob)
1001        p.pack_int(trans_handle)
1002        self.sock.send(p.get_buffer() + blob_id)
1003
1004    @wire_operation
1005    def _op_create_blob2(self, trans_handle):
1006        p = xdrlib.Packer()
1007        p.pack_int(self.op_create_blob2)
1008        p.pack_int(0)
1009        p.pack_int(trans_handle)
1010        p.pack_int(0)
1011        p.pack_int(0)
1012        self.sock.send(p.get_buffer())
1013
1014    @wire_operation
1015    def _op_get_segment(self, blob_handle):
1016        p = xdrlib.Packer()
1017        p.pack_int(self.op_get_segment)
1018        p.pack_int(blob_handle)
1019        p.pack_int(self.buffer_length)
1020        p.pack_int(0)
1021        self.sock.send(p.get_buffer())
1022
1023    @wire_operation
1024    def _op_put_segment(self, blob_handle, seg_data):
1025        ln = len(seg_data)
1026        p = xdrlib.Packer()
1027        p.pack_int(self.op_put_segment)
1028        p.pack_int(blob_handle)
1029        p.pack_int(ln)
1030        p.pack_int(ln)
1031        pad_length = (4-ln) & 3
1032        self.sock.send(p.get_buffer() + seg_data + bs([0])*pad_length)
1033
1034    @wire_operation
1035    def _op_batch_segments(self, blob_handle, seg_data):
1036        ln = len(seg_data)
1037        p = xdrlib.Packer()
1038        p.pack_int(self.op_batch_segments)
1039        p.pack_int(blob_handle)
1040        p.pack_int(ln + 2)
1041        p.pack_int(ln + 2)
1042        pad_length = ((4-(ln+2)) & 3)
1043        self.sock.send(p.get_buffer() + int_to_bytes(ln, 2) + seg_data + bs([0])*pad_length)
1044
1045    @wire_operation
1046    def _op_close_blob(self, blob_handle):
1047        p = xdrlib.Packer()
1048        p.pack_int(self.op_close_blob)
1049        p.pack_int(blob_handle)
1050        self.sock.send(p.get_buffer())
1051
1052    @wire_operation
1053    def _op_que_events(self, event_names, event_id):
1054        if self.db_handle is None:
1055            raise OperationalError('_op_que_events() Invalid db handle')
1056        params = bs([1])
1057        for name, n in event_names.items():
1058            params += bs([len(name)])
1059            params += self.str_to_bytes(name)
1060            params += int_to_bytes(n, 4)
1061        p = xdrlib.Packer()
1062        p.pack_int(self.op_que_events)
1063        p.pack_int(self.db_handle)
1064        p.pack_bytes(params)
1065        p.pack_int(0)    # ast
1066        p.pack_int(0)    # args
1067        p.pack_int(event_id)
1068        self.sock.send(p.get_buffer())
1069
1070    @wire_operation
1071    def _op_cancel_events(self, event_id):
1072        if self.db_handle is None:
1073            raise OperationalError('_op_cancel_events() Invalid db handle')
1074        p = xdrlib.Packer()
1075        p.pack_int(self.op_cancel_events)
1076        p.pack_int(self.db_handle)
1077        p.pack_int(event_id)
1078        self.sock.send(p.get_buffer())
1079
1080    @wire_operation
1081    def _op_connect_request(self):
1082        if self.db_handle is None:
1083            raise OperationalError('_op_connect_request() Invalid db handle')
1084        p = xdrlib.Packer()
1085        p.pack_int(self.op_connect_request)
1086        p.pack_int(1)    # async
1087        p.pack_int(self.db_handle)
1088        p.pack_int(0)
1089        self.sock.send(p.get_buffer())
1090
1091    @wire_operation
1092    def _op_response(self):
1093        b = self.recv_channel(4)
1094        while bytes_to_bint(b) == self.op_dummy:
1095            b = self.recv_channel(4)
1096        op_code = bytes_to_bint(b)
1097        while op_code == self.op_response and self.lazy_response_count:
1098            self.lazy_response_count -= 1
1099            h, oid, buf = self._parse_op_response()
1100            b = self.recv_channel(4)
1101        if op_code == self.op_cont_auth:
1102            raise OperationalError('Unauthorized')
1103        elif op_code != self.op_response:
1104            raise InternalError("_op_response:op_code = %d" % (op_code,))
1105        return self._parse_op_response()
1106
1107    @wire_operation
1108    def _op_event(self):
1109        b = self.recv_channel(4)
1110        while bytes_to_bint(b) == self.op_dummy:
1111            b = self.recv_channel(4)
1112        op_code = bytes_to_bint(b)
1113        if op_code == self.op_response and self.lazy_response_count:
1114            self.lazy_response_count -= 1
1115            self._parse_op_response()
1116            b = self.recv_channel(4)
1117        if op_code == self.op_exit or bytes_to_bint(b) == self.op_exit:
1118            raise DisconnectByPeer
1119        if op_code != self.op_event:
1120            if op_code == self.op_response:
1121                self._parse_op_response()
1122            raise InternalError("_op_event:op_code = %d" % (op_code,))
1123        return self._parse_op_event()
1124
1125    @wire_operation
1126    def _op_sql_response(self, xsqlda):
1127        b = self.recv_channel(4)
1128        while bytes_to_bint(b) == self.op_dummy:
1129            b = self.recv_channel(4)
1130        op_code = bytes_to_bint(b)
1131        if op_code != self.op_sql_response:
1132            if op_code == self.op_response:
1133                self._parse_op_response()
1134            raise InternalError("_op_sql_response:op_code = %d" % (op_code,))
1135
1136        b = self.recv_channel(4)
1137        count = bytes_to_bint(b[:4])
1138        r = []
1139        if count == 0:
1140            return []
1141        if self.accept_version < PROTOCOL_VERSION13:
1142            for i in range(len(xsqlda)):
1143                x = xsqlda[i]
1144                if x.io_length() < 0:
1145                    b = self.recv_channel(4)
1146                    ln = bytes_to_bint(b)
1147                else:
1148                    ln = x.io_length()
1149                raw_value = self.recv_channel(ln, word_alignment=True)
1150                if self.recv_channel(4) == bs([0]) * 4:     # Not NULL
1151                    r.append(x.value(raw_value))
1152                else:
1153                    r.append(None)
1154        else:
1155            n = len(xsqlda) // 8
1156            if len(xsqlda) % 8 != 0:
1157                n += 1
1158            null_indicator = 0
1159            for c in reversed(self.recv_channel(n, word_alignment=True)):
1160                null_indicator <<= 8
1161                null_indicator += c if PYTHON_MAJOR_VER == 3 else ord(c)
1162            for i in range(len(xsqlda)):
1163                x = xsqlda[i]
1164                if null_indicator & (1 << i):
1165                    r.append(None)
1166                else:
1167                    if x.io_length() < 0:
1168                        b = self.recv_channel(4)
1169                        ln = bytes_to_bint(b)
1170                    else:
1171                        ln = x.io_length()
1172                    raw_value = self.recv_channel(ln, word_alignment=True)
1173                    r.append(x.value(raw_value))
1174        return r
1175
1176    def _wait_for_event(self, timeout):
1177        event_names = {}
1178        event_id = 0
1179        while True:
1180            b4 = self.recv_channel(4)
1181            if b4 is None:
1182                return None
1183            op_code = bytes_to_bint(b4)
1184            if op_code == self.op_dummy:
1185                pass
1186            elif op_code == self.op_exit or op_code == self.op_disconnect:
1187                break
1188            elif op_code == self.op_event:
1189                bytes_to_int(self.recv_channel(4))  # db_handle
1190                ln = bytes_to_bint(self.recv_channel(4))
1191                b = self.recv_channel(ln, word_alignment=True)
1192                assert byte_to_int(b[0]) == 1
1193                i = 1
1194                while i < len(b):
1195                    ln = byte_to_int(b[i])
1196                    s = self.connection.bytes_to_str(b[i+1:i+1+ln])
1197                    n = bytes_to_int(b[i+1+ln:i+1+ln+4])
1198                    event_names[s] = n
1199                    i += ln + 5
1200                self.recv_channel(8)  # ignore AST info
1201
1202                event_id = bytes_to_bint(self.recv_channel(4))
1203                break
1204            else:
1205                raise InternalError("_wait_for_event:op_code = %d" % (op_code,))
1206
1207        return (event_id, event_names)
1208