1"""
2The framework for making database tests.
3"""
4
5import logging
6import os
7import re
8import sys
9from py.test import raises
10import sqlobject
11import sqlobject.conftest as conftest
12
if sys.platform.startswith("win"):
    def getcwd():
        # Windows paths use backslashes; normalise them to forward slashes
        # so the result can be embedded in a sqlite:/// URI.
        return '/'.join(os.getcwd().split('\\'))
else:
    # Elsewhere the OS already returns forward-slash paths.
    getcwd = os.getcwd
18
19"""
20supportsMatrix defines what database backends support what features.
21Each feature has a name, if you see a key like '+featureName' then
22only the databases listed support the feature.  Conversely,
23'-featureName' means all databases *except* the ones listed support
24the feature.  The databases are given by their SQLObject string name,
25separated by spaces.
26
27The function supports(featureName) returns True or False based on this,
28and you can use it like::
29
30    def test_featureX():
31        if not supports('featureX'):
32            return
33"""
34supportsMatrix = {
35    '+exceptions': 'mysql postgres sqlite',
36    '-transactions': 'mysql rdbhost',
37    '-dropTableCascade': 'sybase mssql mysql',
38    '-expressionIndex': 'mysql sqlite firebird mssql',
39    '-blobData': 'mssql rdbhost',
40    '-decimalColumn': 'mssql',
41    '-emptyTable': 'mssql',
42    '-limitSelect' : 'mssql',
43    '+schema' : 'postgres',
44    '+memorydb': 'sqlite',
45    }
46
47
def setupClass(soClasses, force=False):
    """
    Makes sure the classes have a corresponding and correct table.
    This won't recreate the table if it already exists.  It will check
    that the table is properly defined (in case you change your table
    definition).

    You can provide a single class or a list of classes; if a list
    then classes will be created in the order you provide, and
    destroyed in the opposite order.  So if class A depends on class
    B, then do setupClass([B, A]) and B won't be destroyed or cleared
    until after A is destroyed or cleared.

    If force is true, then the database will be recreated no matter
    what.

    All the classes are bound to one new connection obtained from
    getConnection().  Returns the classes: the list/tuple passed in, or
    a one-element list wrapping a single class.
    """
    if not isinstance(soClasses, (list, tuple)):
        soClasses = [soClasses]
    # One connection is shared by all the classes set up together.
    connection = getConnection()
    for soClass in soClasses:
        soClass._connection = connection
    installOrClear(soClasses, force=force)
    return soClasses
80
# SQLite file in the current directory that backs InstalledTestDatabase,
# the bookkeeping table recording which test tables have been created.
installedDBFilename = os.path.join(getcwd(), 'dbs_data.tmp')
installedDBTracker = sqlobject.connectionForURI(
    'sqlite:///%s' % installedDBFilename)
85
def getConnection(**kw):
    """
    Open a connection to the database selected by getConnectionURI().

    Extra keyword arguments are passed straight to
    sqlobject.connectionForURI().  The conftest ``show_sql`` and
    ``show_sql_output`` options switch on the connection's ``debug`` and
    ``debugOutput`` flags respectively.
    """
    conn = sqlobject.connectionForURI(getConnectionURI(), **kw)
    options = conftest.option
    if options.show_sql:
        conn.debug = True
    if options.show_sql_output:
        conn.debugOutput = True
    return conn
94
def getConnectionURI():
    """
    Return the URI of the database chosen by conftest.option.Database.

    When that value is a key of conftest.connectionShortcuts, the
    corresponding full URI is returned instead.
    """
    requested = conftest.option.Database
    shortcuts = conftest.connectionShortcuts
    if requested not in shortcuts:
        return requested
    return shortcuts[requested]
100
101try:
102    connection = getConnection()
103except Exception, e:
104    # At least this module should be importable...
105    print >> sys.stderr, (
106        "Could not open database: %s" % e)
107
108
class InstalledTestDatabase(sqlobject.SQLObject):
    """
    This table is set up in SQLite (always, regardless of --Database) and
    tracks what tables have been set up in the 'real' database.  This
    way we don't keep recreating the tables over and over when there
    are multiple tests that use a table.

    Each row remembers the CREATE SQL that was used for one table
    (``table_name``) on one database (``connectionURI``).
    """

    _connection = installedDBTracker
    table_name = sqlobject.StringCol(notNull=True)
    createSQL = sqlobject.StringCol(notNull=True)
    connectionURI = sqlobject.StringCol(notNull=True)

    @classmethod
    def installOrClear(cls, soClasses, force=False):
        """
        Give every class in `soClasses` an empty, up-to-date table.

        If every existing table matches the CREATE SQL recorded for it,
        the existing tables are only cleared.  Otherwise -- or whenever
        `force` is true -- all existing tables are dropped.  Tables are
        dropped/cleared in reverse order and (re)created in the given
        order, so list the classes others depend on first.
        """
        cls.setup()
        # Materialize once: the classes are walked several times below,
        # and a one-shot iterator would silently skip the later passes.
        classes = list(soClasses)
        reversedClasses = classes[::-1]
        # If anything needs to be dropped, they all must be dropped.
        # But if we're forcing it, then we'll always drop.
        any_drops = bool(force)
        if not any_drops:
            for soClass in reversedClasses:
                table = soClass.sqlmeta.table
                if not soClass._connection.tableExists(table):
                    # Missing tables are simply created further down.
                    continue
                records = list(cls.selectBy(
                    table_name=table,
                    connectionURI=soClass._connection.uri()))
                # NOTE(review): install() records a custom <dbName>Create
                # SQL when the class defines one, but the comparison here
                # is against createTableSQL(), so such tables always look
                # stale and get rebuilt.
                newSQL = soClass.createTableSQL()[0]
                if not records or records[0].createSQL != newSQL:
                    # Untracked table, or its definition changed.
                    any_drops = True
                    break
        for soClass in reversedClasses:
            if soClass._connection.tableExists(soClass.sqlmeta.table):
                if any_drops:
                    cls.drop(soClass)
                else:
                    cls.clear(soClass)
        for soClass in classes:
            if not soClass._connection.tableExists(soClass.sqlmeta.table):
                cls.install(soClass)

    @classmethod
    def install(cls, soClass):
        """
        Creates the given table in its database and records the CREATE
        SQL that was used, replacing any earlier record for that table on
        that connection.

        A class attribute named ``<dbName>Create`` (e.g. ``sqliteCreate``)
        overrides the generated CREATE statement.  Otherwise the table is
        created without constraints, and the extra SQL returned by
        createTableSQL() is run after the record is written.
        """
        conn = soClass._connection
        sql = getattr(soClass, conn.dbName + 'Create', None)
        all_extra = []
        if sql:
            conn.query(sql)
        else:
            sql, extra_sql = soClass.createTableSQL()
            soClass.createTable(applyConstraints=False)
            all_extra.extend(extra_sql)
        table = soClass.sqlmeta.table
        uri = conn.uri()
        # Remove stale records first.  They are left behind when a table
        # is dropped as part of a group rebuild, or was removed outside of
        # this tracker.  Without this cleanup, duplicates would accumulate
        # and installOrClear() could compare against obsolete SQL.
        for old in list(cls.selectBy(table_name=table, connectionURI=uri)):
            old.destroySelf()
        cls(table_name=table,
            createSQL=sql,
            connectionURI=uri)
        for extra_sql in all_extra:
            conn.query(extra_sql)

    @classmethod
    def drop(cls, soClass):
        """
        Drops the given table from its database.

        A class attribute named ``<dbName>Drop`` overrides the default
        DROP performed by soClass.dropTable().
        """
        sql = getattr(soClass, soClass._connection.dbName + 'Drop', None)
        if sql:
            soClass._connection.query(sql)
        else:
            soClass.dropTable()

    @classmethod
    def clear(cls, soClass):
        """
        Removes all the rows from a table.
        """
        soClass.clearTable()

    @classmethod
    def setup(cls):
        """
        This sets up *this* table (creating it if it does not exist yet).
        """
        if not cls._connection.tableExists(cls.sqlmeta.table):
            cls.createTable()

# Module-level shortcut used by setupClass() and friends.
installOrClear = InstalledTestDatabase.installOrClear
209
class Dummy(object):

    """
    Used for creating fake objects; a really poor 'mock object'.

    Every keyword argument given to the constructor becomes an
    attribute of the new instance.
    """

    def __init__(self, **kw):
        for attr in kw:
            setattr(self, attr, kw[attr])
219
def inserts(cls, data, schema=None):
    """
    Creates a bunch of rows.

    You can use it like::

        inserts(Person, [{'fname': 'blah', 'lname': 'doe'}, ...])

    Or::

        inserts(Person, [('blah', 'doe')], schema=
                ['fname', 'lname'])

    If you give a single string for the `schema` then it'll split
    that string to get the list of column names.

    Returns the created instances as a list, in the order of `data`.
    """
    if schema:
        if isinstance(schema, str):
            schema = schema.split()
        # Pair each positional row with the column names.
        data = [dict(zip(schema, row)) for row in data]
    return [cls(**kwargs) for kwargs in data]
250
def supports(feature, dbName=None, matrix=None):
    """
    Tell whether a database backend supports `feature`.

    A '+feature' entry in the matrix lists the only backends that support
    the feature; a '-feature' entry lists the only backends that don't.
    A '+' entry is consulted first.

    `dbName` is the backend name to check and defaults to the module-level
    ``connection.dbName``.  `matrix` is the mapping to consult and
    defaults to ``supportsMatrix``.

    Returns True or False.  Raises ValueError if the matrix lists the
    feature under neither prefix.
    """
    if matrix is None:
        matrix = supportsMatrix
    if dbName is None:
        dbName = connection.dbName
    support = matrix.get('+' + feature)
    if support is not None:
        return dbName in support.split()
    notSupport = matrix.get('-' + feature)
    if notSupport is not None:
        return dbName not in notSupport.split()
    # An explicit raise (not assert) so this check survives `python -O`.
    raise ValueError(
        "The supportsMatrix does not list this feature: %r" % feature)
266
267
# To avoid name clashes: a private alias of inserts(), presumably for code
# that defines its own `inserts` but still needs this helper.
_inserts = inserts
270
def setSQLiteConnectionFactory(TableClass, factory):
    """
    Rebind TableClass to a new SQLiteConnection built with `factory`.

    The new connection copies the settings of TableClass's current
    connection.  The table is then installed or cleared on it via
    installOrClear().  Presumably the current connection must already be
    an SQLite one, since its ``filename`` attribute is read -- confirm.
    """
    from sqlobject.sqlite.sqliteconnection import SQLiteConnection
    current = TableClass._connection
    settings = dict(
        filename=current.filename,
        name=current.name,
        debug=current.debug,
        debugOutput=current.debugOutput,
        cache=current.cache,
        style=current.style,
        autoCommit=current.autoCommit,
        debugThreading=current.debugThreading,
        registry=current.registry,
    )
    TableClass._connection = SQLiteConnection(factory=factory, **settings)
    installOrClear([TableClass])
282
def deprecated_module():
    """
    Set both sqlobject.main.warnings_level and exception_level to None.

    These are the same settings setup_module() applies to '_old' modules,
    which exercise backward-compatible methods.
    """
    main = sqlobject.main
    main.warnings_level = main.exception_level = None
286
def setup_module(mod):
    """
    py.test per-module setup hook: configure sqlobject.main for `mod`.

    Modules whose name ends with '_old' test backward compatible methods,
    so they don't get warnings or errors: both ``warnings_level`` and
    ``exception_level`` are set to None.  Every other module gets
    ``warnings_level = None`` and ``exception_level = 0``.
    """
    mod_name = str(mod.__name__)
    # Strip a trailing file extension before checking for '_old'.  '.py'
    # is what a file-style name ends with; '/py' is kept for compatibility.
    if mod_name.endswith('.py') or mod_name.endswith('/py'):
        mod_name = mod_name[:-3]
    sqlobject.main.warnings_level = None
    if mod_name.endswith('_old'):
        sqlobject.main.exception_level = None
    else:
        sqlobject.main.exception_level = 0
299
def teardown_module(mod=None):
    """
    py.test per-module teardown hook.

    Resets sqlobject.main to the settings setup_module() uses for ordinary
    modules (``warnings_level`` None, ``exception_level`` 0).  `mod` is
    accepted for the hook signature and ignored.
    """
    main = sqlobject.main
    main.warnings_level = None
    main.exception_level = 0
303
def setupLogging():
    """
    Attach a stderr StreamHandler to the root logger.

    The handler is timestamped, its level is NOTSET, and it is added on
    every call.
    """
    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s'))
    handler.setLevel(logging.NOTSET)
    logging.getLogger().addHandler(handler)
312
# Public helpers exported to test modules via `from ... import *`.
__all__ = [
    'getConnection', 'getConnectionURI', 'setupClass', 'Dummy', 'raises',
    'inserts', 'supports', 'deprecated_module',
    'setup_module', 'teardown_module', 'setupLogging',
]
316