1##############################################################################
2# Copyright (c) 2009-2018, Hajime Nakagami<nakagami@gmail.com>
3# All rights reserved.
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are met:
7#
8# * Redistributions of source code must retain the above copyright notice, this
9#   list of conditions and the following disclaimer.
10#
11# * Redistributions in binary form must reproduce the above copyright notice,
12#  this list of conditions and the following disclaimer in the documentation
13#  and/or other materials provided with the distribution.
14#
15# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25#
26# Python DB-API 2.0 module for Firebird.
27##############################################################################
import datetime
import decimal
import struct

from firebirdsql.consts import *
from firebirdsql.utils import *
from firebirdsql.wireprotocol import INFO_SQL_SELECT_DESCRIBE_VARS
from firebirdsql.tz_utils import get_tzinfo
from firebirdsql import decfloat
36
37
class XSQLVAR:
    """Description of one column of a prepared statement.

    The attributes start out empty and are filled in one by one from the
    server's describe reply (see ``parse_select_items``).  The descriptor is
    then used to size values on the wire (``io_length``), to report column
    metadata (``display_length`` / ``precision``) and to convert raw wire
    bytes into Python objects (``value``).
    """

    # Size in bytes of a value of each type.  SQL_TYPE_TEXT is not listed
    # because its size is the column's ``sqllen`` (see io_length);
    # SQL_TYPE_VARYING is -1 because its length is not fixed.
    type_length = {
        SQL_TYPE_VARYING: -1,
        SQL_TYPE_SHORT: 4,
        SQL_TYPE_LONG: 4,
        SQL_TYPE_FLOAT: 4,
        SQL_TYPE_TIME: 4,
        SQL_TYPE_DATE: 4,
        SQL_TYPE_DOUBLE: 8,
        SQL_TYPE_TIMESTAMP: 8,
        SQL_TYPE_BLOB: 8,
        SQL_TYPE_ARRAY: 8,
        SQL_TYPE_QUAD: 8,
        SQL_TYPE_INT64: 8,
        SQL_TYPE_INT128: 16,
        SQL_TYPE_TIMESTAMP_TZ: 10,   # 4 date + 4 time + 2 time zone (see value)
        SQL_TYPE_TIME_TZ: 6,         # 4 time + 2 time zone (see value)
        SQL_TYPE_DEC64: 8,
        SQL_TYPE_DEC128: 16,
        SQL_TYPE_DEC_FIXED: 16,
        SQL_TYPE_BOOLEAN: 1,
    }

    # Width reported by display_length() and therefore precision().
    # For SHORT/LONG/INT64/INT128 it is the digit count of the widest value
    # plus one for the sign.  SQL_TYPE_TEXT uses ``sqllen`` instead.
    type_display_length = {
        SQL_TYPE_VARYING: -1,
        SQL_TYPE_SHORT: 6,
        SQL_TYPE_LONG: 11,
        SQL_TYPE_FLOAT: 17,
        SQL_TYPE_TIME: 11,
        SQL_TYPE_DATE: 10,
        SQL_TYPE_DOUBLE: 17,
        SQL_TYPE_TIMESTAMP: 22,
        SQL_TYPE_BLOB: 0,
        SQL_TYPE_ARRAY: -1,
        SQL_TYPE_QUAD: 20,
        SQL_TYPE_INT64: 20,
        SQL_TYPE_INT128: 40,         # 39 digits (2**127 - 1) + sign
        SQL_TYPE_TIMESTAMP_TZ: 28,
        SQL_TYPE_TIME_TZ: 17,
        SQL_TYPE_DEC64: 16,
        SQL_TYPE_DEC128: 34,
        SQL_TYPE_DEC_FIXED: 34,
        SQL_TYPE_BOOLEAN: 5,
    }

    def __init__(self, bytes_to_str):
        """Create an empty column descriptor.

        :param bytes_to_str: callable used by ``value`` to decode the raw
            bytes of SQL_TYPE_TEXT and SQL_TYPE_VARYING columns.
        """
        self.bytes_to_str = bytes_to_str
        self.sqltype = None
        self.sqlscale = None
        self.sqlsubtype = None
        self.sqllen = None
        self.null_ok = None
        self.fieldname = ''
        self.relname = ''
        self.ownname = ''
        self.aliasname = ''

    def io_length(self):
        """Return the number of bytes a value of this column occupies.

        SQL_TYPE_TEXT columns use ``sqllen``; every other type is looked up
        in ``type_length`` (-1 for SQL_TYPE_VARYING).

        :raises KeyError: if ``sqltype`` is not in ``type_length``.
        """
        sqltype = self.sqltype
        if sqltype == SQL_TYPE_TEXT:
            return self.sqllen
        return self.type_length[sqltype]

    def display_length(self):
        """Return the display width of this column.

        SQL_TYPE_TEXT columns use ``sqllen``; every other type is looked up
        in ``type_display_length``.

        :raises KeyError: if ``sqltype`` is not in ``type_display_length``.
        """
        sqltype = self.sqltype
        if sqltype == SQL_TYPE_TEXT:
            return self.sqllen
        return self.type_display_length[sqltype]

    def precision(self):
        """Return the column precision (currently identical to display_length())."""
        return self.display_length()

    def __str__(self):
        """Return all descriptor fields as ``[sqltype,sqlscale,...,aliasname]``."""
        s = ','.join([
            str(self.sqltype), str(self.sqlscale), str(self.sqlsubtype),
            str(self.sqllen), str(self.null_ok), self.fieldname,
            self.relname, self.ownname, self.aliasname,
        ])
        return '[' + s + ']'

    def _parse_date(self, raw_value):
        """Convert a raw DATE value into a ``(year, month, day)`` tuple.

        The integer decoded from ``raw_value`` with ``bytes_to_bint`` is
        offset by 678882 and then split into century, year, month and day
        using integer arithmetic only.
        """
        nday = bytes_to_bint(raw_value) + 678882
        # 146097 = days in 400 Gregorian years.
        century = (4 * nday - 1) // 146097
        nday = 4 * nday - 1 - 146097 * century
        day = nday // 4

        # 1461 = days in 4 years (one leap cycle).
        nday = (4 * day + 3) // 1461
        day = 4 * day + 3 - 1461 * nday
        day = (day + 4) // 4

        month = (5 * day - 3) // 153
        day = 5 * day - 3 - 153 * month
        day = (day + 5) // 5
        year = 100 * century + nday
        # ``month`` is counted from March (0 = March ... 9 = December,
        # 10 = January, 11 = February); January and February therefore
        # belong to the following calendar year.
        if month < 10:
            month += 3
        else:
            month -= 9
            year += 1
        return year, month, day

    def _parse_time(self, raw_value):
        """Convert a raw TIME value into ``(hour, minute, second, microsecond)``.

        The integer decoded from ``raw_value`` counts units of 1/10000
        second, so the sub-second remainder is scaled by 100 to obtain
        microseconds.
        """
        total_seconds, fraction = divmod(bytes_to_bint(raw_value), 10000)
        total_minutes, second = divmod(total_seconds, 60)
        hour, minute = divmod(total_minutes, 60)
        return (hour, minute, second, fraction * 100)

    def _parse_time_zone(self, raw_value):
        """Return the tzinfo for the unsigned time-zone id in ``raw_value``."""
        return get_tzinfo(bytes_to_uint(raw_value))

    def value(self, raw_value):
        """Convert the raw wire bytes of this column into a Python object.

        * TEXT -> ``bytes_to_str`` result with trailing whitespace removed;
          VARYING -> ``bytes_to_str`` result as-is.
        * SHORT / LONG / INT64 / INT128 -> ``int``, or ``decimal.Decimal``
          scaled by ``10 ** sqlscale`` when ``sqlscale`` is non-zero.
        * DATE / TIME / TIMESTAMP and the time-zone variants ->
          ``datetime.date`` / ``datetime.time`` / ``datetime.datetime``
          (the ``*_TZ`` types carry a tzinfo).
        * FLOAT / DOUBLE -> ``float`` (big-endian IEEE); BOOLEAN -> ``bool``.
        * DEC_FIXED / DEC64 / DEC128 -> result of the ``decfloat`` helpers.
        * Any other type (e.g. BLOB ids) -> ``raw_value`` unchanged.
        """
        if self.sqltype == SQL_TYPE_TEXT:
            return self.bytes_to_str(raw_value).rstrip()
        elif self.sqltype == SQL_TYPE_VARYING:
            return self.bytes_to_str(raw_value)
        elif self.sqltype in (SQL_TYPE_SHORT, SQL_TYPE_LONG, SQL_TYPE_INT64, SQL_TYPE_INT128):
            n = bytes_to_bint(raw_value)
            if self.sqlscale:
                # Build the Decimal from text so the scale is exact.
                return decimal.Decimal(str(n) + 'e' + str(self.sqlscale))
            return n
        elif self.sqltype == SQL_TYPE_DATE:
            yyyy, mm, dd = self._parse_date(raw_value)
            return datetime.date(yyyy, mm, dd)
        elif self.sqltype == SQL_TYPE_TIME:
            h, m, s, usec = self._parse_time(raw_value)
            return datetime.time(h, m, s, usec)
        elif self.sqltype == SQL_TYPE_TIMESTAMP:
            yyyy, mm, dd = self._parse_date(raw_value[:4])
            h, m, s, usec = self._parse_time(raw_value[4:])
            return datetime.datetime(yyyy, mm, dd, h, m, s, usec)
        elif self.sqltype == SQL_TYPE_FLOAT:
            return struct.unpack('!f', raw_value)[0]
        elif self.sqltype == SQL_TYPE_DOUBLE:
            return struct.unpack('!d', raw_value)[0]
        elif self.sqltype == SQL_TYPE_BOOLEAN:
            return byte_to_int(raw_value[0]) != 0
        elif self.sqltype == SQL_TYPE_TIMESTAMP_TZ:
            yyyy, mm, dd = self._parse_date(raw_value[:4])
            h, m, s, usec = self._parse_time(raw_value[4:8])
            tz = self._parse_time_zone(raw_value[8:])
            return datetime.datetime(yyyy, mm, dd, h, m, s, usec, tzinfo=tz)
        elif self.sqltype == SQL_TYPE_TIME_TZ:
            h, m, s, usec = self._parse_time(raw_value[:4])
            tz = self._parse_time_zone(raw_value[4:])
            return datetime.time(h, m, s, usec, tzinfo=tz)
        elif self.sqltype == SQL_TYPE_DEC_FIXED:
            return decfloat.decimal_fixed_to_decimal(raw_value, self.sqlscale)
        elif self.sqltype == SQL_TYPE_DEC64:
            return decfloat.decimal64_to_decimal(raw_value)
        elif self.sqltype == SQL_TYPE_DEC128:
            return decfloat.decimal128_to_decimal(raw_value)
        else:
            return raw_value
199
200
# BLR type descriptor that calc_blr() appends for every SQL type it does not
# encode explicitly (i.e. types that need no length or scale operand taken
# from the column).  The names in the trailing comments are the Firebird
# blr.h constants these codes are believed to correspond to.
# NOTE(review): names come from Firebird's blr.h, not from this file -- verify.
sqltype2blr = {
    SQL_TYPE_DOUBLE: [27],          # blr_double
    SQL_TYPE_FLOAT: [10],           # blr_float
    SQL_TYPE_D_FLOAT: [11],         # blr_d_float
    SQL_TYPE_DATE: [12],            # blr_sql_date
    SQL_TYPE_TIME: [13],            # blr_sql_time
    SQL_TYPE_TIMESTAMP: [35],       # blr_timestamp
    SQL_TYPE_BLOB: [9, 0],          # blr_quad with scale 0 (blob id)
    SQL_TYPE_ARRAY: [9, 0],         # blr_quad with scale 0 (array id)
    SQL_TYPE_BOOLEAN: [23],         # blr_bool
    SQL_TYPE_DEC64: [24],           # blr_dec64
    SQL_TYPE_DEC128: [25],          # blr_dec128
    SQL_TYPE_TIME_TZ: [28],         # blr_sql_time_tz
    SQL_TYPE_TIMESTAMP_TZ: [29],    # blr_timestamp_tz
    }
216
217
def calc_blr(xsqlda):
    """Build the BLR message that describes the columns in ``xsqlda``.

    Each column contributes its own type descriptor followed by a
    ``[blr_short, 0]`` null indicator, so the message declares two fields
    per column.  The result is returned as a byte string.
    """
    field_count = 2 * len(xsqlda)
    # Message header followed by the field count as a little-endian 16-bit
    # value.  NOTE(review): 5, 2, 4, 0 are presumably blr_version5,
    # blr_begin, blr_message and message number 0 -- verify against blr.h.
    codes = [5, 2, 4, 0, field_count & 0xFF, field_count >> 8]

    # Types encoded as [code, sqllen low byte, sqllen high byte].
    length_codes = {
        SQL_TYPE_VARYING: 37,
        SQL_TYPE_TEXT: 14,
    }
    # Types encoded as [code, sqlscale].
    scale_codes = {
        SQL_TYPE_LONG: 8,
        SQL_TYPE_SHORT: 7,
        SQL_TYPE_INT64: 16,
        SQL_TYPE_INT128: 26,
        SQL_TYPE_QUAD: 9,
        SQL_TYPE_DEC_FIXED: 26,
    }

    for var in xsqlda:
        kind = var.sqltype
        if kind in length_codes:
            codes.extend([length_codes[kind], var.sqllen & 0xFF, var.sqllen >> 8])
        elif kind in scale_codes:
            codes.extend([scale_codes[kind], var.sqlscale])
        else:
            codes.extend(sqltype2blr[kind])
        codes.extend([7, 0])    # null indicator: [blr_short, 0]
    codes.extend([255, 76])     # [blr_end, blr_eoc]

    # A scale is normally negative; map every code into the 0..255 byte range.
    return bs(code + 256 if code < 0 else code for code in codes)
247
248
def parse_select_items(buf, xsqlda, connection):
    """Decode one describe-vars reply into the ``xsqlda`` list, in place.

    ``buf`` is a run of ``isc_info_sql_*`` clusters terminated by
    ``isc_info_end``.  Most clusters are ``item, 2-byte length, payload``;
    ``isc_info_sql_describe_end`` is a single byte.  An
    ``isc_info_sql_sqlda_seq`` cluster selects which column (1-based) the
    following clusters describe and stores a fresh XSQLVAR in that slot.

    :returns: the current column index when the reply ends with
        ``isc_info_truncated`` (the caller requests more starting there),
        or -1 once ``isc_info_end`` is reached.
    """
    # Clusters whose payload is stored on the current XSQLVAR as an int.
    int_fields = {
        isc_info_sql_sub_type: 'sqlsubtype',
        isc_info_sql_scale: 'sqlscale',
        isc_info_sql_length: 'sqllen',
        isc_info_sql_null_ind: 'null_ok',
    }
    # Clusters whose payload is stored on the current XSQLVAR as text.
    str_fields = {
        isc_info_sql_field: 'fieldname',
        isc_info_sql_relation: 'relname',
        isc_info_sql_owner: 'ownname',
        isc_info_sql_alias: 'aliasname',
    }

    index = 0
    pos = 0
    item = byte_to_int(buf[pos])
    while item != isc_info_end:
        has_payload = (
            item == isc_info_sql_sqlda_seq
            or item == isc_info_sql_type
            or item in int_fields
            or item in str_fields
        )
        if has_payload:
            size = bytes_to_int(buf[pos + 1:pos + 3])
            payload = buf[pos + 3:pos + 3 + size]
            if item == isc_info_sql_sqlda_seq:
                index = bytes_to_int(payload)
                xsqlda[index - 1] = XSQLVAR(
                    connection.bytes_to_ustr if connection.use_unicode
                    else connection.bytes_to_str
                )
            elif item == isc_info_sql_type:
                # The low bit is cleared; presumably it is Firebird's
                # "nullable" flag rather than part of the type -- verify.
                sqltype = bytes_to_int(payload) & ~1
                xsqlda[index - 1].sqltype = sqltype
            elif item in int_fields:
                number = bytes_to_int(payload)
                setattr(xsqlda[index - 1], int_fields[item], number)
            else:
                text = connection.bytes_to_str(payload)
                setattr(xsqlda[index - 1], str_fields[item], text)
            pos += 3 + size
        elif item == isc_info_truncated:
            return index    # caller re-requests starting at this column
        elif item == isc_info_sql_describe_end:
            pos += 1
        else:
            print('\t', item, 'Invalid item [%02x] ! i=%d' % (buf[pos], pos))
            pos += 1
        item = byte_to_int(buf[pos])
    return -1   # no more info
304
305
def parse_xsqlda(buf, connection, stmt_handle):
    """Parse a statement info reply into ``(stmt_type, xsqlda)``.

    ``buf`` is scanned for an ``isc_info_sql_stmt_type`` cluster and an
    ``isc_info_sql_select`` / ``isc_info_sql_describe_vars`` section.  If
    the describe section is truncated, further ``_op_info_sql`` requests
    are sent for ``stmt_handle`` until every column has been described.

    :returns: ``(stmt_type, xsqlda)``.  ``stmt_type`` is None when the reply
        carried no statement type.  ``xsqlda`` is a list with one XSQLVAR per
        column, or empty when there is no describe section.
    :raises AssertionError: if a continuation reply is not a describe-vars
        section for the same number of columns.
    """
    xsqlda = []
    stmt_type = None
    i = 0
    while i < len(buf):
        if buf[i:i+3] == bs([isc_info_sql_stmt_type, 0x04, 0x00]):
            # 2-byte length (always 4 here) followed by the statement type.
            stmt_type = bytes_to_int(buf[i+3:i+7])
            i += 7
        elif buf[i:i+2] == bs([isc_info_sql_select, isc_info_sql_describe_vars]):
            i += 2
            size = bytes_to_int(buf[i:i+2])
            i += 2
            col_len = bytes_to_int(buf[i:i+size])
            xsqlda = [None] * col_len
            next_index = parse_select_items(buf[i+size:], xsqlda, connection)
            while next_index > 0:   # reply was truncated: ask for the rest
                connection._op_info_sql(
                    stmt_handle,
                    bs([isc_info_sql_sqlda_start, 2]) + int_to_bytes(next_index, 2) + INFO_SQL_SELECT_DESCRIBE_VARS
                )
                # Keep the continuation in its own name so ``buf`` and the
                # offset ``i`` into it stay consistent.
                _, _, more = connection._op_response()
                # Explicit checks (not ``assert``) so they also run under -O.
                # NOTE(review): 0x04, 0x07 are presumably isc_info_sql_select
                # and isc_info_sql_describe_vars.
                if more[:2] != bs([0x04, 0x07]):
                    raise AssertionError(
                        'unexpected describe-vars continuation reply')
                size = bytes_to_int(more[2:4])
                if bytes_to_int(more[4:4+size]) != col_len:
                    raise AssertionError(
                        'column count changed in describe-vars continuation')
                next_index = parse_select_items(more[4+size:], xsqlda, connection)
            # parse_select_items consumed the describe section, but ``i``
            # still points at the column-count bytes.  Scanning on from there
            # could misread them as a cluster header and clobber stmt_type.
            break
        else:
            break
    return stmt_type, xsqlda
334