1# oracle/zxjdbc.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""
9.. dialect:: oracle+zxjdbc
10    :name: zxJDBC for Jython
11    :dbapi: zxjdbc
12    :connectstring: oracle+zxjdbc://user:pass@host/dbname
13    :driverurl: http://www.oracle.com/technetwork/database/features/jdbc/index-091264.html
14
15    .. note:: Jython is not supported by current versions of SQLAlchemy.  The
16       zxjdbc dialect should be considered as experimental.
17
18"""  # noqa
19import collections
20import decimal
21import re
22
23from .base import OracleCompiler
24from .base import OracleDialect
25from .base import OracleExecutionContext
26from ... import sql
27from ... import types as sqltypes
28from ... import util
29from ...connectors.zxJDBC import ZxJDBCConnector
30from ...engine import result as _result
31from ...sql import expression
32
33
# Java-side names used further down in this module.  They stay ``None``
# at import time; OracleDialect_zxjdbc.__init__ imports the real objects
# and rebinds these module globals.
SQLException = None
zxJDBC = None
35
36
class _ZxJDBCDate(sqltypes.Date):
    """Date type whose result values are reduced with ``.date()``."""

    def result_processor(self, dialect, coltype):
        """Return a converter applied to each fetched value.

        The converter passes ``None`` through unchanged and otherwise
        returns ``value.date()``.
        """

        def to_date(raw):
            return raw.date() if raw is not None else None

        return to_date
46
47
class _ZxJDBCNumeric(sqltypes.Numeric):
    """Numeric type that coerces results according to ``asdecimal``."""

    def result_processor(self, dialect, coltype):
        """Return a converter applied to each fetched value.

        With ``asdecimal`` set, non-Decimal values are converted through
        ``decimal.Decimal(str(value))``; otherwise Decimal values are
        converted to ``float`` and anything else is returned as-is.  In
        both modes ``None`` (SQL NULL) is returned unchanged.
        """
        # XXX: does the dialect return Decimal or not???
        # if it does (in all cases), we could use a None processor as well as
        # the to_float generic processor
        if self.asdecimal:

            def process(value):
                # NULL has to stay None: Decimal(str(None)) would raise
                # decimal.InvalidOperation.
                if value is None or isinstance(value, decimal.Decimal):
                    return value
                return decimal.Decimal(str(value))

        else:

            def process(value):
                if isinstance(value, decimal.Decimal):
                    return float(value)
                return value

        return process
70
71
class OracleCompiler_zxjdbc(OracleCompiler):
    """Compiler that renders ``RETURNING ... INTO`` using bind placeholders
    whose values are :class:`ReturningParam` objects.
    """

    def returning_clause(self, stmt, returning_cols):
        """Render the ``RETURNING <columns> INTO <binds>`` clause.

        Side effects on the compiler:

        * ``returning_cols`` is replaced with the expanded column list;
        * one ``(position, dbtype)`` pair per column is appended to
          ``returning_parameters`` (created on first use; positions start
          at 1);
        * one ``ret_<n>`` bind parameter per column (``n`` starting at 0)
          is stored in ``self.binds``.
        """
        self.returning_cols = list(
            expression._select_iterables(returning_cols)
        )

        # within_columns_clause=False so that labels (foo AS bar) don't render
        column_sql = []
        for column in self.returning_cols:
            column_sql.append(
                self.process(column, within_columns_clause=False)
            )

        if not hasattr(self, "returning_parameters"):
            self.returning_parameters = []

        placeholders = []
        for position, column in enumerate(self.returning_cols, 1):
            impl_type = column.type.dialect_impl(self.dialect)
            dbtype = impl_type.get_dbapi_type(self.dialect.dbapi)
            self.returning_parameters.append((position, dbtype))

            param = sql.bindparam(
                "ret_%d" % (position - 1), value=ReturningParam(dbtype)
            )
            self.binds[param.key] = param
            placeholder = self.bindparam_string(
                self._truncate_bindparam(param)
            )
            placeholders.append(placeholder)

        column_list = ", ".join(column_sql)
        target_list = ", ".join(placeholders)
        return "RETURNING " + column_list + " INTO " + target_list
103
104
class OracleExecutionContext_zxjdbc(OracleExecutionContext):
    """Execution context that reads RETURNING values back from the
    prepared statement when the compiled statement carries
    ``returning_parameters``.
    """

    def pre_exec(self):
        """Replace ``self.statement`` with a prepared statement.

        Only done when the compiled statement has a
        ``returning_parameters`` attribute; otherwise nothing happens.
        """
        if hasattr(self.compiled, "returning_parameters"):
            # prepare a zxJDBC statement so we can grab its underlying
            # OraclePreparedStatement's getReturnResultSet later
            self.statement = self.cursor.prepare(self.statement)

    def get_result_proxy(self):
        """Return the result object for the executed statement.

        When the compiled statement has ``returning_parameters``, one value
        per ``(index, dbtype)`` pair is read from the statement's return
        result set and wrapped in a :class:`ReturningResultProxy`.  A
        ``SQLException`` raised while obtaining or advancing that result
        set is re-raised as ``zxJDBC.Error`` with the message, SQL code
        and (if present) SQL state.  The result set and the prepared
        statement are closed in all cases.

        Without ``returning_parameters`` a plain ``_result.ResultProxy``
        is returned.
        """
        if hasattr(self.compiled, "returning_parameters"):
            # Stays None if getReturnResultSet() itself raises, so the
            # ``finally`` clause knows there is nothing to close.
            rrs = None
            try:
                try:
                    rrs = self.statement.__statement__.getReturnResultSet()
                    # NOTE(review): presumably positions the result set on
                    # its first row; the return value is ignored -- confirm.
                    next(rrs)
                except SQLException as sqle:
                    # Fold message, error code and optional SQL state into
                    # a single DBAPI-level error.
                    msg = "%s [SQLCode: %d]" % (
                        sqle.getMessage(),
                        sqle.getErrorCode(),
                    )
                    if sqle.getSQLState() is not None:
                        msg += " [SQLState: %s]" % sqle.getSQLState()
                    raise zxJDBC.Error(msg)
                else:
                    # One value per (index, dbtype) pair recorded by the
                    # compiler, converted by the cursor's data handler.
                    row = tuple(
                        self.cursor.datahandler.getPyObject(rrs, index, dbtype)
                        for index, dbtype in self.compiled.returning_parameters
                    )
                    return ReturningResultProxy(self, row)
            finally:
                # Runs on both the return and the raise path above.
                if rrs is not None:
                    try:
                        rrs.close()
                    except SQLException:
                        # Failure to close the result set is ignored.
                        pass
                self.statement.close()

        return _result.ResultProxy(self)

    def create_cursor(self):
        """Create a cursor whose data handler is wrapped by the dialect's
        ``DataHandler`` class.
        """
        cursor = self._dbapi_connection.cursor()
        # Wrap the cursor's existing handler in the dialect's handler class.
        cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
        return cursor
147
148
class ReturningResultProxy(_result.FullyBufferedResultProxy):
    """Result proxy serving a single, already-fetched RETURNING row."""

    def __init__(self, context, returning_row):
        # The row is stored before the base initializer is invoked.
        self._returning_row = returning_row
        super(ReturningResultProxy, self).__init__(context)

    def _cursor_description(self):
        """Describe each returning column as a ``(name, type)`` pair.

        Columns without a ``name`` attribute are described by their
        ``anon_label`` instead.
        """
        return [
            (col.name if hasattr(col, "name") else col.anon_label, col.type)
            for col in self.context.compiled.returning_cols
        ]

    def _buffer_rows(self):
        """Return a deque holding only the stored RETURNING row."""
        buffered = collections.deque()
        buffered.append(self._returning_row)
        return buffered
168
169
class ReturningParam(object):

    """A bindparam value representing a RETURNING parameter.

    Specially handled by OracleReturningDataHandler.

    Two instances compare equal when their ``type`` values are equal, and
    equal instances produce the same hash.
    """

    def __init__(self, type_):
        self.type = type_

    def __eq__(self, other):
        if isinstance(other, ReturningParam):
            return self.type == other.type
        return NotImplemented

    def __ne__(self, other):
        # Spelled out explicitly: Python 2 does not derive ``!=`` from
        # ``__eq__``.
        if isinstance(other, ReturningParam):
            return self.type != other.type
        return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone leaves the class unhashable on Python 3 and
        # identity-hashed on Python 2; hash on the same key that __eq__
        # compares.  The class is fixed to ReturningParam because __eq__
        # accepts subclasses via isinstance().
        return hash((ReturningParam, self.type))

    def __repr__(self):
        kls = self.__class__
        return "<%s.%s object at 0x%x type=%s>" % (
            kls.__module__,
            kls.__name__,
            id(self),
            self.type,
        )
198
199
class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
    """Oracle dialect for the zxJDBC DBAPI.

    Uses :class:`OracleCompiler_zxjdbc` and
    :class:`OracleExecutionContext_zxjdbc`, and overrides the ``Date`` and
    ``Numeric`` type implementations.
    """

    jdbc_db_name = "oracle"
    jdbc_driver_name = "oracle.jdbc.OracleDriver"

    statement_compiler = OracleCompiler_zxjdbc
    execution_ctx_cls = OracleExecutionContext_zxjdbc

    colspecs = util.update_copy(
        OracleDialect.colspecs,
        {sqltypes.Date: _ZxJDBCDate, sqltypes.Numeric: _ZxJDBCNumeric},
    )

    def __init__(self, *args, **kwargs):
        """Initialize the dialect and bind the Java-side helpers.

        The Java imports happen here rather than at module import time;
        ``SQLException`` and ``zxJDBC`` are published as module globals,
        and ``self.DataHandler`` is set to a handler class that knows
        about :class:`ReturningParam`.
        """
        super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs)
        global SQLException, zxJDBC
        from java.sql import SQLException
        from com.ziclix.python.sql import zxJDBC
        from com.ziclix.python.sql.handler import OracleDataHandler

        class OracleReturningDataHandler(OracleDataHandler):
            """zxJDBC DataHandler that specially handles ReturningParam."""

            def setJDBCObject(self, statement, index, object_, dbtype=None):
                if type(object_) is ReturningParam:
                    # Register an output slot instead of binding a value.
                    statement.registerReturnParameter(index, object_.type)
                elif dbtype is None:
                    OracleDataHandler.setJDBCObject(
                        self, statement, index, object_
                    )
                else:
                    OracleDataHandler.setJDBCObject(
                        self, statement, index, object_, dbtype
                    )

        self.DataHandler = OracleReturningDataHandler

    def initialize(self, connection):
        """Run base initialization, then set ``implicit_returning``.

        ``implicit_returning`` is enabled when the JDBC driver version is
        at least 10.2.
        """
        super(OracleDialect_zxjdbc, self).initialize(connection)
        self.implicit_returning = self._driver_supports_returning(
            connection.connection.driverversion
        )

    @staticmethod
    def _driver_supports_returning(driverversion):
        """Return True if ``driverversion`` denotes driver 10.2 or later.

        The leading ``major[.minor]`` numbers are compared numerically.  A
        plain string comparison orders ``"9.2"`` after ``"10.2"`` and
        ``"10.10"`` before it, which is wrong for versions.
        """
        try:
            match = re.match(r"\s*(\d+)(?:\.(\d+))?", driverversion)
        except TypeError:
            # Not a string at all.
            match = None
        if match is None:
            # Unrecognised format: keep the historical comparison.
            return driverversion >= "10.2"
        major = int(match.group(1))
        minor = int(match.group(2) or 0)
        return (major, minor) >= (10, 2)

    def _create_jdbc_url(self, url):
        """Build a ``jdbc:oracle:thin`` URL from ``url``.

        The port defaults to 1521 when ``url.port`` is not set.
        """
        return "jdbc:oracle:thin:@%s:%s:%s" % (
            url.host,
            url.port or 1521,
            url.database,
        )

    def _get_server_version_info(self, connection):
        """Return the server version as a tuple of ints.

        The version is taken from the ``Release x.y.z`` portion of the
        connection's ``dbversion`` string.
        """
        version = re.search(
            r"Release ([\d\.]+)", connection.connection.dbversion
        ).group(1)
        # Skip empty components so a stray dot captured by the pattern
        # does not make int() fail.
        return tuple(int(x) for x in version.split(".") if x)
252
253
# Module-level alias for this module's dialect class.
# NOTE(review): presumably the attribute the dialect loader looks up on
# this module -- confirm against the registry code.
dialect = OracleDialect_zxjdbc
255