"""
The framework for making database tests.
"""

import logging
import os
import re
import sys
from py.test import raises
import sqlobject
import sqlobject.conftest as conftest

if sys.platform[:3] == "win":
    def getcwd():
        """Return os.getcwd() with backslashes replaced by forward slashes."""
        return os.getcwd().replace('\\', '/')
else:
    getcwd = os.getcwd

"""
supportsMatrix defines what database backends support what features.
Each feature has a name, if you see a key like '+featureName' then
only the databases listed support the feature.  Conversely,
'-featureName' means all databases *except* the ones listed support
the feature.  The databases are given by their SQLObject string name,
separated by spaces.

The function supports(featureName) returns True or False based on this,
and you can use it like::

    def test_featureX():
        if not supports('featureX'):
            return
"""
supportsMatrix = {
    '+exceptions': 'mysql postgres sqlite',
    '-transactions': 'mysql rdbhost',
    '-dropTableCascade': 'sybase mssql mysql',
    '-expressionIndex': 'mysql sqlite firebird mssql',
    '-blobData': 'mssql rdbhost',
    '-decimalColumn': 'mssql',
    '-emptyTable': 'mssql',
    '-limitSelect': 'mssql',
    '+schema': 'postgres',
    '+memorydb': 'sqlite',
    }


def setupClass(soClasses, force=False):
    """
    Makes sure the classes have a corresponding and correct table.
    This won't recreate the table if it already exists.  It will check
    that the table is properly defined (in case you change your table
    definition).

    You can provide a single class or a list of classes; if a list
    then classes will be created in the order you provide, and
    destroyed in the opposite order.  So if class A depends on class
    B, then do setupClass([B, A]) and B won't be destroyed or cleared
    until after A is destroyed or cleared.

    If force is true, then the database will be recreated no matter
    what.

    Returns the list (or tuple) of classes that were set up; a single
    class passed in is returned wrapped in a one-element list.
    """
    # Only needed by the commented-out ConnectionHub alternative below.
    global hub
    if not isinstance(soClasses, (list, tuple)):
        soClasses = [soClasses]
    # One fresh connection is shared by every class in this call.
    connection = getConnection()
    for soClass in soClasses:
        ## This would be an alternate way to register connections...
        #try:
        #    hub
        #except NameError:
        #    hub = sqlobject.dbconnection.ConnectionHub()
        #soClass._connection = hub
        #hub.threadConnection = connection
        #hub.processConnection = connection
        soClass._connection = connection
    installOrClear(soClasses, force=force)
    return soClasses

# SQLite file (in the current working directory) backing the tracker table.
installedDBFilename = os.path.join(getcwd(), 'dbs_data.tmp')

installedDBTracker = sqlobject.connectionForURI(
    'sqlite:///' + installedDBFilename)

def getConnection(**kw):
    """
    Opens a connection to the database selected by the test options.

    The URI comes from getConnectionURI(); any keyword arguments are
    passed straight through to sqlobject.connectionForURI().  The
    connection's ``debug`` and ``debugOutput`` flags are switched on
    when the ``show_sql`` / ``show_sql_output`` options are set.
    """
    name = getConnectionURI()
    conn = sqlobject.connectionForURI(name, **kw)
    if conftest.option.show_sql:
        conn.debug = True
    if conftest.option.show_sql_output:
        conn.debugOutput = True
    return conn

def getConnectionURI():
    """
    Returns the value of the ``Database`` test option, expanded through
    conftest.connectionShortcuts when it names one of the shortcuts.
    """
    name = conftest.option.Database
    if name in conftest.connectionShortcuts:
        name = conftest.connectionShortcuts[name]
    return name

try:
    connection = getConnection()
except Exception:
    # At least this module should be importable...
    # sys.exc_info() and sys.stderr.write() are used instead of the
    # ``except X, e`` / ``print >>`` forms so that this block parses on
    # every Python version; the emitted text is the same.
    # Note that ``connection`` stays undefined when this branch runs.
    sys.stderr.write(
        "Could not open database: %s\n" % sys.exc_info()[1])


class InstalledTestDatabase(sqlobject.SQLObject):
    """
    This table is set up in SQLite (always, regardless of --Database) and
    tracks what tables have been set up in the 'real' database.  This
    way we don't keep recreating the tables over and over when there
    are multiple tests that use a table.
    """

    _connection = installedDBTracker
    # Name of the tracked table (soClass.sqlmeta.table).
    table_name = sqlobject.StringCol(notNull=True)
    # The SQL that was used to create the tracked table.
    createSQL = sqlobject.StringCol(notNull=True)
    # URI of the connection the tracked table was created on.
    connectionURI = sqlobject.StringCol(notNull=True)

    @classmethod
    def installOrClear(cls, soClasses, force=False):
        """
        Brings the tables of soClasses into a clean, up-to-date state.

        Existing tables are cleared when their recorded creation SQL
        still matches the class definition; if any table is out of date
        (or force is true) every existing table is dropped instead.
        Tables that do not exist afterwards are installed, in the order
        given.  Dropping/clearing happens in the opposite order.
        """
        cls.setup()
        teardownOrder = list(soClasses)
        teardownOrder.reverse()
        # If anything needs to be dropped, they all must be dropped
        # But if we're forcing it, then we'll always drop
        any_drops = bool(force)
        for soClass in teardownOrder:
            table = soClass.sqlmeta.table
            if not soClass._connection.tableExists(table):
                continue
            items = list(cls.selectBy(
                table_name=table,
                connectionURI=soClass._connection.uri()))
            if items:
                instance = items[0]
                sql = instance.createSQL
            else:
                # Table exists but was never recorded by this tracker.
                sql = None
            newSQL, constraints = soClass.createTableSQL()
            if sql != newSQL:
                if sql is not None:
                    # Stale record: the definition changed since install.
                    instance.destroySelf()
                any_drops = True
                # One mismatch is enough to force dropping everything.
                break
        for soClass in teardownOrder:
            if soClass._connection.tableExists(soClass.sqlmeta.table):
                if any_drops:
                    cls.drop(soClass)
                else:
                    cls.clear(soClass)
        for soClass in soClasses:
            table = soClass.sqlmeta.table
            if not soClass._connection.tableExists(table):
                cls.install(soClass)

    @classmethod
    def install(cls, soClass):
        """
        Creates the given table in its database.

        If the class has a ``<dbName>Create`` attribute, that SQL is run
        verbatim; otherwise the table is created from the class
        definition and its constraint statements are run afterwards.
        A tracker row recording the creation SQL is added either way.
        """
        sql = getattr(soClass, soClass._connection.dbName + 'Create',
                      None)
        all_extra = []
        if sql:
            soClass._connection.query(sql)
        else:
            sql, extra_sql = soClass.createTableSQL()
            soClass.createTable(applyConstraints=False)
            all_extra.extend(extra_sql)
        # Instantiating the SQLObject class inserts the tracker row.
        cls(table_name=soClass.sqlmeta.table,
            createSQL=sql,
            connectionURI=soClass._connection.uri())
        for extra_sql in all_extra:
            soClass._connection.query(extra_sql)

    @classmethod
    def drop(cls, soClass):
        """
        Drops the given table from its database.

        Uses the class's ``<dbName>Drop`` SQL when it defines one,
        otherwise soClass.dropTable().
        """
        sql = getattr(soClass, soClass._connection.dbName + 'Drop', None)
        if sql:
            soClass._connection.query(sql)
        else:
            soClass.dropTable()

    @classmethod
    def clear(cls, soClass):
        """
        Removes all the rows from a table.
        """
        soClass.clearTable()

    @classmethod
    def setup(cls):
        """
        This sets up *this* table.
        """
        if not cls._connection.tableExists(cls.sqlmeta.table):
            cls.createTable()

installOrClear = InstalledTestDatabase.installOrClear

class Dummy(object):

    """
    Used for creating fake objects; a really poor 'mock object'.
    """

    def __init__(self, **kw):
        """Sets every keyword argument as an attribute of the instance."""
        for name, value in kw.items():
            setattr(self, name, value)

def inserts(cls, data, schema=None):
    """
    Creates a bunch of rows.

    You can use it like::

        inserts(Person, [{'fname': 'blah', 'lname': 'doe'}, ...])

    Or::

        inserts(Person, [('blah', 'doe')], schema=
                ['fname', 'lname'])

    If you give a single string for the `schema` then it'll split
    that string to get the list of column names.

    Returns the list of created instances, in the order of `data`.
    """
    if schema:
        if isinstance(schema, str):
            schema = schema.split()
        # Turn each positional row into a keyword dictionary.
        keywordData = []
        for item in data:
            itemDict = {}
            for name, value in zip(schema, item):
                itemDict[name] = value
            keywordData.append(itemDict)
        data = keywordData
    results = []
    for args in data:
        results.append(cls(**args))
    return results

def supports(feature):
    """
    Tells whether the current database backend supports `feature`,
    according to supportsMatrix.

    The backend is taken from the module-level ``connection``, which is
    only defined when the connection attempt made at import time
    succeeded.  A '+feature' entry is consulted before a '-feature'
    entry.  An AssertionError is raised for a feature that has neither
    entry.
    """
    dbName = connection.dbName
    support = supportsMatrix.get('+' + feature, None)
    notSupport = supportsMatrix.get('-' + feature, None)
    if support is not None and dbName in support.split():
        return True
    elif support:
        return False
    if notSupport is not None and dbName in notSupport.split():
        return False
    elif notSupport:
        return True
    assert notSupport is not None or support is not None, (
        "The supportMatrix does not list this feature: %r"
        % feature)


# To avoid name clashes:
_inserts = inserts

def setSQLiteConnectionFactory(TableClass, factory):
    """
    Replaces TableClass's connection with a new SQLiteConnection that
    copies the settings of the current one and uses `factory`, then
    runs installOrClear() for the class on that new connection.
    """
    from sqlobject.sqlite.sqliteconnection import SQLiteConnection
    conn = TableClass._connection
    TableClass._connection = SQLiteConnection(
        filename=conn.filename,
        name=conn.name, debug=conn.debug, debugOutput=conn.debugOutput,
        cache=conn.cache, style=conn.style, autoCommit=conn.autoCommit,
        debugThreading=conn.debugThreading, registry=conn.registry,
        factory=factory
    )
    installOrClear([TableClass])

def deprecated_module():
    """
    Sets both sqlobject.main.warnings_level and
    sqlobject.main.exception_level to None.

    NOTE(review): presumably this silences deprecation warnings and
    errors for the calling test module -- confirm against sqlobject.main.
    """
    sqlobject.main.warnings_level = None
    sqlobject.main.exception_level = None

def setup_module(mod):
    """
    Module-level test setup: picks the deprecation levels for `mod`.

    Modules with '_old' test backward compatible methods, so they
    don't get warnings or errors (exception_level is None); every
    other module gets exception_level 0.
    """
    mod_name = str(mod.__name__)
    # Strip a trailing extension before looking for '_old'.  '.py' is the
    # suffix the three-character slice is meant for; '/py' is still
    # accepted because earlier versions of this check matched it.
    if mod_name.endswith('.py') or mod_name.endswith('/py'):
        mod_name = mod_name[:-3]
    if mod_name.endswith('_old'):
        sqlobject.main.warnings_level = None
        sqlobject.main.exception_level = None
    else:
        sqlobject.main.warnings_level = None
        sqlobject.main.exception_level = 0

def teardown_module(mod=None):
    """
    Module-level test teardown: restores warnings_level to None and
    exception_level to 0 on sqlobject.main.
    """
    sqlobject.main.warnings_level = None
    sqlobject.main.exception_level = 0

def setupLogging():
    """
    Adds a stderr StreamHandler with a timestamped format to the root
    logger.  Each call adds another handler.
    """
    fmt = '[%(asctime)s] %(name)s %(levelname)s: %(message)s'
    formatter = logging.Formatter(fmt)
    hdlr = logging.StreamHandler(sys.stderr)
    hdlr.setFormatter(formatter)
    hdlr.setLevel(logging.NOTSET)
    logger = logging.getLogger()
    logger.addHandler(hdlr)

__all__ = ['getConnection', 'getConnectionURI', 'setupClass', 'Dummy', 'raises',
           'inserts', 'supports', 'deprecated_module',
           'setup_module', 'teardown_module', 'setupLogging']