1# $Id: 85799ff3b36841e5fe97bf09dfc658849b83a71f $
2
3"""
4Oracle extended database driver.
5"""
6
# Tell documentation tools that docstrings in this module use reStructuredText.
__docformat__ = 'restructuredtext en'
8
9# ---------------------------------------------------------------------------
10# Imports
11# ---------------------------------------------------------------------------
12
13import os
14import sys
15
16from grizzled.db.base import (DBDriver, Error, Warning,
17                              TableMetadata, IndexMetadata, RDBMSMetadata)
18
19# ---------------------------------------------------------------------------
20# Constants
21# ---------------------------------------------------------------------------
22
# Vendor and product names reported by OracleDriver.get_rdbms_metadata().
VENDOR, PRODUCT = 'Oracle Corporation', 'Oracle'
25
26# ---------------------------------------------------------------------------
27# Classes
28# ---------------------------------------------------------------------------
29
class OracleDriver(DBDriver):
    """
    DB Driver for Oracle, using the cx_Oracle DB API module.

    NOTE(review): the metadata queries read the ``all_*`` data-dictionary
    views and match only on the lower-cased table name, not on the owning
    schema. Same-named tables in different schemas that are visible to the
    connected user are therefore reported together. Confirm whether that is
    intended.
    """

    def get_import(self):
        """
        Import and return the cx_Oracle DB API module.

        The import happens at call time rather than at module load, so this
        module can be imported even when cx_Oracle is not installed.
        """
        import cx_Oracle
        return cx_Oracle

    def get_display_name(self):
        """Return the human-readable name of this driver, ``"Oracle"``."""
        return "Oracle"

    def do_connect(self,
                   host='localhost',
                   port=None,
                   user='',
                   password='',
                   database='default'):
        """
        Connect to Oracle using a ``user/password@database`` connect string.

        :param host: accepted for interface compatibility but not used.
                     The target is selected entirely by ``database``.
        :param port: accepted for interface compatibility but not used.
        :param user: Oracle user name.
        :param password: password for ``user``.
        :param database: the part of the connect string after ``@``. It is
                         passed to cx_Oracle unchanged, so presumably it is
                         a TNS alias or Easy Connect string.
        :return: the connection object returned by ``cx_Oracle.connect()``.
        """
        dbi = self.get_import()
        # NOTE(review): a user name or password containing '/' or '@' will
        # be mis-parsed in this combined string. Passing user, password and
        # dsn as separate connect() arguments would avoid that. However, it
        # may change how an empty user/password ("/@dsn") is authenticated,
        # so the combined form is kept here.
        return dbi.connect('%s/%s@%s' % (user, password, database))

    def get_tables(self, cursor):
        """
        Return the lower-cased names of the tables listed in ``all_tables``.

        Names containing a ``$`` are skipped.

        :param cursor: an open DB API cursor on an Oracle connection.
        :return: a list of table names, in the order the query returns them.
        """
        cursor.execute('select lower(table_name) from all_tables')
        # iter(fetchone, None) yields rows until fetchone() returns None.
        return [row[0] for row in iter(cursor.fetchone, None)
                if '$' not in row[0]]

    def get_rdbms_metadata(self, cursor):
        """
        Return an RDBMSMetadata describing the connected server.

        The version string is the ``banner`` of the first ``v$version`` row
        returned whose banner starts with ``Oracle``. It is ``'unknown'`` if
        no such row exists.
        """
        cursor.execute("SELECT banner FROM v$version WHERE "
                       "banner LIKE 'Oracle%'")
        row = cursor.fetchone()
        version = 'unknown' if row is None else row[0]
        return RDBMSMetadata(VENDOR, PRODUCT, version)

    def get_table_metadata(self, table, cursor):
        """
        Return column metadata for ``table``, in ``column_id`` order.

        :param table: table name, matched case-insensitively.
        :param cursor: an open DB API cursor on an Oracle connection.
        :return: a list of TableMetadata objects, one per column, built from
                 (column name, data type, length, precision, scale,
                 nullable flag).
        """
        # _ensure_valid_table is inherited from DBDriver (not shown here).
        # Presumably it rejects unknown tables -- TODO confirm.
        self._ensure_valid_table(cursor, table)
        # Bind the table name instead of interpolating it into the SQL text.
        # The explicit ORDER BY is needed because Oracle does not guarantee
        # row order without one. Ordering by owner first keeps each
        # schema's columns together.
        cursor.execute("select column_name, data_type, data_length, "
                       "data_precision, data_scale, nullable, "
                       "char_col_decl_length from all_tab_columns "
                       "where lower(table_name) = :tname "
                       "order by owner, column_id",
                       {'tname': table.lower()})
        results = []
        for (column, coltype, data_length, precision, scale, nullable,
             declared_char_length) in iter(cursor.fetchone, None):
            # Use the declared character length when it is set (non-NULL,
            # non-zero). Otherwise fall back to data_length.
            length = declared_char_length or data_length
            results.append(TableMetadata(column,
                                         coltype,
                                         length,
                                         precision,
                                         scale,
                                         nullable == 'Y'))
        return results

    def get_index_metadata(self, table, cursor):
        """
        Return index metadata for ``table``.

        :param table: table name, matched case-insensitively.
        :param cursor: an open DB API cursor on an Oracle connection.
        :return: a list of IndexMetadata objects, one per index, ordered by
                 index name. Each holds:

                 - the index name;
                 - its columns, as ``"<column_name> <descend>"`` strings in
                   column-position order;
                 - a description of the form ``"[Temporary ]<index type>
                   <unique|non-unique> index[ (max_extents=N)]"``.
        """
        # _ensure_valid_table is inherited from DBDriver (not shown here).
        self._ensure_valid_table(cursor, table)
        params = {'tname': table.lower()}

        # First, list the indexes and build a description of each. ORDER BY
        # makes the result order deterministic.
        cursor.execute("select index_name, index_type, uniqueness, "
                       "max_extents, temporary from all_indexes where "
                       "lower(table_name) = :tname "
                       "order by index_name",
                       params)
        names = []
        description = {}
        for (name, index_type, unique, max_extents,
             temporary) in iter(cursor.fetchone, None):
            desc = 'Temporary ' if temporary == 'Y' else ''
            unique = unique.lower()
            if unique == 'nonunique':
                unique = 'non-unique'
            desc += '%s %s index' % (index_type.lower(), unique)
            if max_extents:
                desc += ' (max_extents=%d)' % max_extents
            names.append(name)
            description[name] = desc

        # Next, collect each index's columns. Rows arrive in column-position
        # order, so appending keeps every index's column list in order.
        cursor.execute("select index_name, column_name, descend "
                       "from all_ind_columns "
                       "where lower(table_name) = :tname "
                       "order by column_position",
                       params)
        columns = {}
        for index_name, column_name, descend in iter(cursor.fetchone, None):
            columns.setdefault(index_name, []).append(
                '%s %s' % (column_name, descend))

        # Finally, assemble the result.
        return [IndexMetadata(name, columns.get(name, []),
                              description.get(name))
                for name in names]
160
161