#!/usr/bin/env python
#
# odb.py
#
# Object Database Api
#
# Written by David Jeske <jeske@neotonic.com>, 2001/07.
# Inspired by eGroups' sqldb.py originally written by Scott Hassan circa 1998.
#
# Copyright (C) 2001, by David Jeske and Neotonic
#
# Goals:
#       - a simple object-like interface to database data
#       - database independent (someday)
#       - relational-style "rigid schema definition"
#       - object style easy-access
#
# Example:
#
#  import odb
#
#  # define table
#  class AgentsTable(odb.Table):
#    def _defineRows(self):
#      self.d_addColumn("agent_id",odb.kInteger,None,primarykey = 1,autoincrement = 1)
#      self.d_addColumn("login",odb.kVarString,200,notnull=1)
#      self.d_addColumn("ticket_count",odb.kIncInteger,None)
#
#  if __name__ == "__main__":
#    # open database
#    ndb = MySQLdb.connect(host = 'localhost',
#                          user='username',
#                          passwd = 'password',
#                          db='testdb')
#    # NOTE: odb.Database is a base class (escape(), listTables(), ... are
#    # unimplemented there); in practice use the subclass for your database.
#    db = odb.Database(ndb)
#    tbl = AgentsTable(db,"agents")
#
#    # create row
#    agent_row = tbl.newRow()
#    agent_row.login = "foo"
#    agent_row.save()
#
#    # fetch row (must use primary key)
#    try:
#      get_row = tbl.fetchRow( ('agent_id', agent_row.agent_id) )
#    except odb.eNoMatchingRows:
#      print "this is bad, we should have found the row"
#
#    # fetch rows (can return empty list)
#    list_rows = tbl.fetchRows( ('login', "foo") )
#

import string
import sys, zlib
from log import *

import handle_error

#####################################
# EXCEPTIONS
#
# These were originally string exceptions ("odb.eNoSuchColumn", ...).
# Raising a string is a TypeError on Python 2.6 and later, so they are
# exception classes now.  The names are unchanged, so existing
# "raise odb.eX, msg" and "except odb.eX:" code keeps working.

class OdbError(Exception):
    """Common base class of all odb exceptions."""

class eNoSuchColumn(OdbError):
    """A column name is not defined on the table."""

class eNonUniqueMatchSpec(OdbError):
    """A match spec does not identify a single row where one is required."""

class eNoMatchingRows(OdbError):
    """fetchRow() found no row for the given match spec."""

class eInternalError(OdbError):
    """An internal consistency check failed."""

class eInvalidMatchSpec(OdbError):
    """A match spec is not a (column, value) tuple or a list of them."""

class eInvalidData(OdbError):
    """A value, or a column/table definition option, is invalid."""

class eUnsavedObjectLost(OdbError):
    """Not raised in this part of the module.

    NOTE(review): presumably signals that unsaved Row changes would be
    discarded -- confirm against the Row implementation.
    """

class eDuplicateKey(OdbError):
    """An insert/update failed with a "Duplicate entry" database error."""

#####################################
# COLUMN TYPES
################                     ######################
# typename     #######################  size data means:
#              #                      #
kInteger       = "kInteger"          # -
74kFixedString = "kFixedString" # size 75kVarString = "kVarString" # maxsize 76kBigString = "kBigString" # - 77kIncInteger = "kIncInteger" # - 78kDateTime = "kDateTime" 79kTimeStamp = "kTimeStamp" 80kReal = "kReal" 81 82 83DEBUG = 0 84 85############## 86# Database 87# 88# this will ultimately turn into a mostly abstract base class for 89# the DB adaptors for different database types.... 90# 91 92class Database: 93 def __init__(self, db, debug=0): 94 self._tables = {} 95 self.db = db 96 self._cursor = None 97 self.compression_enabled = 0 98 self.debug = debug 99 self.SQLError = None 100 101 self.__defaultRowClass = self.defaultRowClass() 102 self.__defaultRowListClass = self.defaultRowListClass() 103 104 def defaultCursor(self): 105 if self._cursor is None: 106 self._cursor = self.db.cursor() 107 return self._cursor 108 109 def escape(self,str): 110 raise "Unimplemented Error" 111 112 def getDefaultRowClass(self): return self.__defaultRowClass 113 def setDefaultRowClass(self, clss): self.__defaultRowClass = clss 114 def getDefaultRowListClass(self): return self.__defaultRowListClass 115 def setDefaultRowListClass(self, clss): self.__defaultRowListClass = clss 116 117 def defaultRowClass(self): 118 return Row 119 120 def defaultRowListClass(self): 121 # base type is list... 
122 return list 123 124 def addTable(self, attrname, tblname, tblclass, 125 rowClass = None, check = 0, create = 0, rowListClass = None): 126 tbl = tblclass(self, tblname, rowClass=rowClass, check=check, 127 create=create, rowListClass=rowListClass) 128 self._tables[attrname] = tbl 129 return tbl 130 131 def close(self): 132 for name, tbl in self._tables.items(): 133 tbl.db = None 134 self._tables = {} 135 if self.db is not None: 136 self.db.close() 137 self.db = None 138 139 def __getattr__(self, key): 140 if key == "_tables": 141 raise AttributeError, "odb.Database: not initialized properly, self._tables does not exist" 142 143 try: 144 table_dict = getattr(self,"_tables") 145 return table_dict[key] 146 except KeyError: 147 raise AttributeError, "odb.Database: unknown attribute %s" % (key) 148 149 def beginTransaction(self, cursor=None): 150 if cursor is None: 151 cursor = self.defaultCursor() 152 dlog(DEV_UPDATE,"begin") 153 cursor.execute("begin") 154 155 def commitTransaction(self, cursor=None): 156 if cursor is None: 157 cursor = self.defaultCursor() 158 dlog(DEV_UPDATE,"commit") 159 cursor.execute("commit") 160 161 def rollbackTransaction(self, cursor=None): 162 if cursor is None: 163 cursor = self.defaultCursor() 164 dlog(DEV_UPDATE,"rollback") 165 cursor.execute("rollback") 166 167 ## 168 ## schema creation code 169 ## 170 171 def createTables(self): 172 tables = self.listTables() 173 174 for attrname, tbl in self._tables.items(): 175 tblname = tbl.getTableName() 176 177 if tblname not in tables: 178 print "table %s does not exist" % tblname 179 tbl.createTable() 180 else: 181 invalidAppCols, invalidDBCols = tbl.checkTable() 182 183## self.alterTableToMatch(tbl) 184 185 def createIndices(self): 186 indices = self.listIndices() 187 188 for attrname, tbl in self._tables.items(): 189 for indexName, (columns, unique) in tbl.getIndices().items(): 190 if indexName in indices: continue 191 192 tbl.createIndex(columns, indexName=indexName, unique=unique) 193 194 
def synchronizeSchema(self): 195 tables = self.listTables() 196 197 for attrname, tbl in self._tables.items(): 198 tblname = tbl.getTableName() 199 self.alterTableToMatch(tbl) 200 201 def listTables(self, cursor=None): 202 raise "Unimplemented Error" 203 204 def listFieldsDict(self, table_name, cursor=None): 205 raise "Unimplemented Error" 206 207 def listFields(self, table_name, cursor=None): 208 columns = self.listFieldsDict(table_name, cursor=cursor) 209 return columns.keys() 210 211########################################## 212# Table 213# 214 215 216class Table: 217 def subclassinit(self): 218 pass 219 def __init__(self,database,table_name, 220 rowClass = None, check = 0, create = 0, rowListClass = None): 221 self.db = database 222 self.__table_name = table_name 223 if rowClass: 224 self.__defaultRowClass = rowClass 225 else: 226 self.__defaultRowClass = database.getDefaultRowClass() 227 228 if rowListClass: 229 self.__defaultRowListClass = rowListClass 230 else: 231 self.__defaultRowListClass = database.getDefaultRowListClass() 232 233 # get this stuff ready! 234 235 self.__column_list = [] 236 self.__vcolumn_list = [] 237 self.__columns_locked = 0 238 self.__has_value_column = 0 239 240 self.__indices = {} 241 242 # this will be used during init... 243 self.__col_def_hash = None 244 self.__vcol_def_hash = None 245 self.__primary_key_list = None 246 self.__relations_by_table = {} 247 248 # ask the subclass to def his rows 249 self._defineRows() 250 251 # get ready to run! 
252 self.__lockColumnsAndInit() 253 254 self.subclassinit() 255 256 if create: 257 self.createTable() 258 259 if check: 260 self.checkTable() 261 262 def _colTypeToSQLType(self, colname, coltype, options): 263 264 if coltype == kInteger: 265 coltype = "integer" 266 elif coltype == kFixedString: 267 sz = options.get('size', None) 268 if sz is None: coltype = 'char' 269 else: coltype = "char(%s)" % sz 270 elif coltype == kVarString: 271 sz = options.get('size', None) 272 if sz is None: coltype = 'varchar' 273 else: coltype = "varchar(%s)" % sz 274 elif coltype == kBigString: 275 coltype = "text" 276 elif coltype == kIncInteger: 277 coltype = "integer" 278 elif coltype == kDateTime: 279 coltype = "datetime" 280 elif coltype == kTimeStamp: 281 coltype = "timestamp" 282 elif coltype == kReal: 283 coltype = "real" 284 285 coldef = "%s %s" % (colname, coltype) 286 287 if options.get('notnull', 0): coldef = coldef + " NOT NULL" 288 if options.get('autoincrement', 0): coldef = coldef + " AUTO_INCREMENT" 289 if options.get('unique', 0): coldef = coldef + " UNIQUE" 290# if options.get('primarykey', 0): coldef = coldef + " primary key" 291 if options.get('default', None) is not None: coldef = coldef + " DEFAULT %s" % options.get('default') 292 293 return coldef 294 295 def getTableName(self): return self.__table_name 296 def setTableName(self, tablename): self.__table_name = tablename 297 298 def getIndices(self): return self.__indices 299 300 def _createTableSQL(self): 301 defs = [] 302 for colname, coltype, options in self.__column_list: 303 defs.append(self._colTypeToSQLType(colname, coltype, options)) 304 305 defs = string.join(defs, ", ") 306 307 primarykeys = self.getPrimaryKeyList() 308 primarykey_str = "" 309 if primarykeys: 310 primarykey_str = ", PRIMARY KEY (" + string.join(primarykeys, ",") + ")" 311 312 sql = "create table %s (%s %s)" % (self.__table_name, defs, primarykey_str) 313 return sql 314 315 def createTable(self, cursor=None): 316 if cursor is None: 
cursor = self.db.defaultCursor() 317 sql = self._createTableSQL() 318 print "CREATING TABLE:", sql 319 cursor.execute(sql) 320 321 def dropTable(self, cursor=None): 322 if cursor is None: cursor = self.db.defaultCursor() 323 try: 324 cursor.execute("drop table %s" % self.__table_name) # clean out the table 325 except self.SQLError, reason: 326 pass 327 328 def renameTable(self, newTableName, cursor=None): 329 if cursor is None: cursor = self.db.defaultCursor() 330 try: 331 cursor.execute("rename table %s to %s" % (self.__table_name, newTableName)) 332 except sel.SQLError, reason: 333 pass 334 335 self.setTableName(newTableName) 336 337 def getTableColumnsFromDB(self): 338 return self.db.listFieldsDict(self.__table_name) 339 340 def checkTable(self, warnflag=1): 341 invalidDBCols = {} 342 invalidAppCols = {} 343 344 dbcolumns = self.getTableColumnsFromDB() 345 for coldef in self.__column_list: 346 colname = coldef[0] 347 348 dbcoldef = dbcolumns.get(colname, None) 349 if dbcoldef is None: 350 invalidAppCols[colname] = 1 351 352 for colname, row in dbcolumns.items(): 353 coldef = self.__col_def_hash.get(colname, None) 354 if coldef is None: 355 invalidDBCols[colname] = 1 356 357 if warnflag == 1: 358 if invalidDBCols: 359 print "----- WARNING ------------------------------------------" 360 print " There are columns defined in the database schema that do" 361 print " not match the application's schema." 362 print " columns:", invalidDBCols.keys() 363 print "--------------------------------------------------------" 364 365 if invalidAppCols: 366 print "----- WARNING ------------------------------------------" 367 print " There are new columns defined in the application schema" 368 print " that do not match the database's schema." 369 print " columns:", invalidAppCols.keys() 370 print "--------------------------------------------------------" 371 372 return invalidAppCols, invalidDBCols 373 374 375 def alterTableToMatch(self): 376 raise "Unimplemented Error!" 
377 378 def addIndex(self, columns, indexName=None, unique=0): 379 if indexName is None: 380 indexName = self.getTableName() + "_index_" + string.join(columns, "_") 381 382 self.__indices[indexName] = (columns, unique) 383 384 def createIndex(self, columns, indexName=None, unique=0, cursor=None): 385 if cursor is None: cursor = self.db.defaultCursor() 386 cols = string.join(columns, ",") 387 388 if indexName is None: 389 indexName = self.getTableName() + "_index_" + string.join(columns, "_") 390 391 uniquesql = "" 392 if unique: 393 uniquesql = " unique" 394 sql = "create %s index %s on %s (%s)" % (uniquesql, indexName, self.getTableName(), cols) 395 warn("creating index", sql) 396 cursor.execute(sql) 397 398 399 ## Column Definition 400 401 def getColumnDef(self,column_name): 402 try: 403 return self.__col_def_hash[column_name] 404 except KeyError: 405 try: 406 return self.__vcol_def_hash[column_name] 407 except KeyError: 408 raise eNoSuchColumn, "no column (%s) on table %s" % (column_name,self.__table_name) 409 410 def getColumnList(self): 411 return self.__column_list + self.__vcolumn_list 412 def getAppColumnList(self): return self.__column_list 413 414 def databaseSizeForData_ColumnName_(self,data,col_name): 415 try: 416 col_def = self.__col_def_hash[col_name] 417 except KeyError: 418 try: 419 col_def = self.__vcol_def_hash[col_name] 420 except KeyError: 421 raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name) 422 423 c_name,c_type,c_options = col_def 424 425 if c_type == kBigString: 426 if c_options.get("compress_ok",0) and self.db.compression_enabled: 427 z_size = len(zlib.compress(data,9)) 428 r_size = len(data) 429 if z_size < r_size: 430 return z_size 431 else: 432 return r_size 433 else: 434 return len(data) 435 else: 436 # really simplistic database size computation: 437 try: 438 a = data[0] 439 return len(data) 440 except: 441 return 4 442 443 444 def columnType(self, col_name): 445 try: 446 col_def = 
self.__col_def_hash[col_name] 447 except KeyError: 448 try: 449 col_def = self.__vcol_def_hash[col_name] 450 except KeyError: 451 raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name) 452 453 c_name,c_type,c_options = col_def 454 return c_type 455 456 def convertDataForColumn(self,data,col_name): 457 try: 458 col_def = self.__col_def_hash[col_name] 459 except KeyError: 460 try: 461 col_def = self.__vcol_def_hash[col_name] 462 except KeyError: 463 raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name) 464 465 c_name,c_type,c_options = col_def 466 467 if c_type == kIncInteger: 468 raise eInvalidData, "invalid operation for column (%s:%s) on table (%s)" % (col_name,c_type,self.__table_name) 469 470 if c_type == kInteger: 471 try: 472 if data is None: data = 0 473 else: return long(data) 474 except (ValueError,TypeError): 475 raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data),col_name,c_type,self.__table_name) 476 elif c_type == kReal: 477 try: 478 if data is None: data = 0.0 479 else: return float(data) 480 except (ValueError,TypeError): 481 raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data), col_name,c_type,self.__table_name) 482 483 else: 484 if type(data) == type(long(0)): 485 return "%d" % data 486 else: 487 return str(data) 488 489 def getPrimaryKeyList(self): 490 return self.__primary_key_list 491 492 def hasValueColumn(self): 493 return self.__has_value_column 494 495 def hasColumn(self,name): 496 return self.__col_def_hash.has_key(name) 497 def hasVColumn(self,name): 498 return self.__vcol_def_hash.has_key(name) 499 500 501 def _defineRows(self): 502 raise "can't instantiate base odb.Table type, make a subclass and override _defineRows()" 503 504 def __lockColumnsAndInit(self): 505 # add a 'odb_value column' before we lockdown the table def 506 if self.__has_value_column: 507 self.d_addColumn("odb_value",kBigText,default='') 508 509 
self.__columns_locked = 1 510 # walk column list and make lookup hashes, primary_key_list, etc.. 511 512 primary_key_list = [] 513 col_def_hash = {} 514 for a_col in self.__column_list: 515 name,type,options = a_col 516 col_def_hash[name] = a_col 517 if options.has_key('primarykey'): 518 primary_key_list.append(name) 519 520 self.__col_def_hash = col_def_hash 521 self.__primary_key_list = primary_key_list 522 523 # setup the value columns! 524 525 if (not self.__has_value_column) and (len(self.__vcolumn_list) > 0): 526 raise "can't define vcolumns on table without ValueColumn, call d_addValueColumn() in your _defineRows()" 527 528 vcol_def_hash = {} 529 for a_col in self.__vcolumn_list: 530 name,type,size_data,options = a_col 531 vcol_def_hash[name] = a_col 532 533 self.__vcol_def_hash = vcol_def_hash 534 535 536 def __checkColumnLock(self): 537 if self.__columns_locked: 538 raise "can't change column definitions outside of subclass' _defineRows() method!" 539 540 # table definition methods, these are only available while inside the 541 # subclass's _defineRows method 542 # 543 # Ex: 544 # 545 # import odb 546 # class MyTable(odb.Table): 547 # def _defineRows(self): 548 # self.d_addColumn("id",kInteger,primarykey = 1,autoincrement = 1) 549 # self.d_addColumn("name",kVarString,120) 550 # self.d_addColumn("type",kInteger, 551 # enum_values = { 0 : "alive", 1 : "dead" } 552 553 def d_addColumn(self,col_name,ctype,size=None,primarykey = 0, 554 notnull = 0,indexed=0, 555 default=None,unique=0,autoincrement=0,safeupdate=0, 556 enum_values = None, 557 no_export = 0, 558 relations=None,compress_ok=0,int_date=0): 559 560 self.__checkColumnLock() 561 562 options = {} 563 options['default'] = default 564 if primarykey: 565 options['primarykey'] = primarykey 566 if unique: 567 options['unique'] = unique 568 if indexed: 569 options['indexed'] = indexed 570 self.addIndex((col_name,)) 571 if safeupdate: 572 options['safeupdate'] = safeupdate 573 if autoincrement: 574 
options['autoincrement'] = autoincrement 575 if notnull: 576 options['notnull'] = notnull 577 if size: 578 options['size'] = size 579 if no_export: 580 options['no_export'] = no_export 581 if int_date: 582 if ctype != kInteger: 583 raise eInvalidData, "can't flag columns int_date unless they are kInteger" 584 else: 585 options['int_date'] = int_date 586 587 if enum_values: 588 options['enum_values'] = enum_values 589 inv_enum_values = {} 590 for k,v in enum_values.items(): 591 if inv_enum_values.has_key(v): 592 raise eInvalidData, "enum_values paramater must be a 1 to 1 mapping for Table(%s)" % self.__table_name 593 else: 594 inv_enum_values[v] = k 595 options['inv_enum_values'] = inv_enum_values 596 if relations: 597 options['relations'] = relations 598 for a_relation in relations: 599 table, foreign_column_name = a_relation 600 if self.__relations_by_table.has_key(table): 601 raise eInvalidData, "multiple relations for the same foreign table are not yet supported" 602 self.__relations_by_table[table] = (col_name,foreign_column_name) 603 if compress_ok: 604 if ctype == kBigString: 605 options['compress_ok'] = 1 606 else: 607 raise eInvalidData, "only kBigString fields can be compress_ok=1" 608 609 self.__column_list.append( (col_name,ctype,options) ) 610 611 def d_addValueColumn(self): 612 self.__checkColumnLock() 613 self.__has_value_column = 1 614 615 def d_addVColumn(self,col_name,type,size=None,default=None): 616 self.__checkColumnLock() 617 618 if (not self.__has_value_column): 619 raise "can't define VColumns on table without ValueColumn, call d_addValueColumn() first" 620 621 options = {} 622 if default: 623 options['default'] = default 624 if size: 625 options['size'] = size 626 627 self.__vcolumn_list.append( (col_name,type,options) ) 628 629 ##################### 630 # _checkColMatchSpec(col_match_spec,should_match_unique_row = 0) 631 # 632 # raise an error if the col_match_spec contains invalid columns, or 633 # (in the case of should_match_unique_row) 
if it does not fully specify 634 # a unique row. 635 # 636 # NOTE: we don't currently support where clauses with value column fields! 637 # 638 639 def _fixColMatchSpec(self,col_match_spec, should_match_unique_row = 0): 640 if type(col_match_spec) == type([]): 641 if type(col_match_spec[0]) != type((0,)): 642 raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)" 643 elif type(col_match_spec) == type((0,)): 644 col_match_spec = [ col_match_spec ] 645 elif type(col_match_spec) == type(None): 646 if should_match_unique_row: 647 raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec 648 else: 649 return None 650 else: 651 raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)" 652 653 if should_match_unique_row: 654 unique_column_lists = [] 655 656 # first the primary key list 657 my_primary_key_list = [] 658 for a_key in self.__primary_key_list: 659 my_primary_key_list.append(a_key) 660 661 # then other unique keys 662 for a_col in self.__column_list: 663 col_name,a_type,options = a_col 664 if options.has_key('unique'): 665 unique_column_lists.append( (col_name, [col_name]) ) 666 667 unique_column_lists.append( ('primary_key', my_primary_key_list) ) 668 669 670 new_col_match_spec = [] 671 for a_col in col_match_spec: 672 name,val = a_col 673 # newname = string.lower(name) 674 # what is this doing?? - jeske 675 newname = name 676 if not self.__col_def_hash.has_key(newname): 677 raise eNoSuchColumn, "no such column in match spec: '%s'" % newname 678 679 new_col_match_spec.append( (newname,val) ) 680 681 if should_match_unique_row: 682 for name,a_list in unique_column_lists: 683 try: 684 a_list.remove(newname) 685 except ValueError: 686 # it's okay if they specify too many columns! 687 pass 688 689 if should_match_unique_row: 690 for name,a_list in unique_column_lists: 691 if len(a_list) == 0: 692 # we matched at least one unique colum spec! 
693 # log("using unique column (%s) for query %s" % (name,col_match_spec)) 694 return new_col_match_spec 695 696 raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec 697 698 return new_col_match_spec 699 700 def __buildWhereClause (self, col_match_spec,other_clauses = None): 701 sql_where_list = [] 702 703 if not col_match_spec is None: 704 for m_col in col_match_spec: 705 m_col_name,m_col_val = m_col 706 c_name,c_type,c_options = self.__col_def_hash[m_col_name] 707 if c_type in (kIncInteger, kInteger): 708 try: 709 m_col_val_long = long(m_col_val) 710 except ValueError: 711 raise ValueError, "invalid literal for long(%s) in table %s" % (repr(m_col_val),self.__table_name) 712 713 sql_where_list.append("%s = %d" % (c_name, m_col_val_long)) 714 elif c_type == kReal: 715 try: 716 m_col_val_float = float(m_col_val) 717 except ValueError: 718 raise ValueError, "invalid literal for float(%s) is table %s" % (repr(m_col_val), self.__table_name) 719 sql_where_list.append("%s = %s" % (c_name, m_col_val_float)) 720 else: 721 sql_where_list.append("%s = '%s'" % (c_name, self.db.escape(m_col_val))) 722 723 if other_clauses is None: 724 pass 725 elif type(other_clauses) == type(""): 726 sql_where_list = sql_where_list + [other_clauses] 727 elif type(other_clauses) == type([]): 728 sql_where_list = sql_where_list + other_clauses 729 else: 730 raise eInvalidData, "unknown type of extra where clause: %s" % repr(other_clauses) 731 732 return sql_where_list 733 734 def __fetchRows(self,col_match_spec,cursor = None, where = None, order_by = None, limit_to = None, 735 skip_to = None, join = None): 736 if cursor is None: 737 cursor = self.db.defaultCursor() 738 739 # build column list 740 sql_columns = [] 741 for name,t,options in self.__column_list: 742 sql_columns.append(name) 743 744 # build join information 745 746 joined_cols = [] 747 joined_cols_hash = {} 748 join_clauses = [] 749 if not join is None: 750 for a_table,retrieve_foreign_cols in 
join: 751 try: 752 my_col,foreign_col = self.__relations_by_table[a_table] 753 for a_col in retrieve_foreign_cols: 754 full_col_name = "%s.%s" % (my_col,a_col) 755 joined_cols_hash[full_col_name] = 1 756 joined_cols.append(full_col_name) 757 sql_columns.append( full_col_name ) 758 759 join_clauses.append(" left join %s as %s on %s=%s " % (a_table,my_col,my_col,foreign_col)) 760 761 except KeyError: 762 eInvalidJoinSpec, "can't find table %s in defined relations for %s" % (a_table,self.__table_name) 763 764 # start buildling SQL 765 sql = "select %s from %s" % (string.join(sql_columns,","), 766 self.__table_name) 767 768 # add join clause 769 if join_clauses: 770 sql = sql + string.join(join_clauses," ") 771 772 # add where clause elements 773 sql_where_list = self.__buildWhereClause (col_match_spec,where) 774 if sql_where_list: 775 sql = sql + " where %s" % (string.join(sql_where_list," and ")) 776 777 # add order by clause 778 if order_by: 779 sql = sql + " order by %s " % string.join(order_by,",") 780 781 # add limit 782 if not limit_to is None: 783 if not skip_to is None: 784# log("limit,skip = %s,%s" % (limit_to,skip_to)) 785 if self.db.db.__module__ == "sqlite.main": 786 sql = sql + " limit %s offset %s " % (limit_to,skip_to) 787 else: 788 sql = sql + " limit %s, %s" % (skip_to,limit_to) 789 else: 790 sql = sql + " limit %s" % limit_to 791 else: 792 if not skip_to is None: 793 raise eInvalidData, "can't specify skip_to without limit_to in MySQL" 794 795 dlog(DEV_SELECT,sql) 796 cursor.execute(sql) 797 798 # create defaultRowListClass instance... 799 return_rows = self.__defaultRowListClass() 800 801 # should do fetchmany! 
802 all_rows = cursor.fetchall() 803 for a_row in all_rows: 804 data_dict = {} 805 806 col_num = 0 807 808 # for a_col in cursor.description: 809 # (name,type_code,display_size,internal_size,precision,scale,null_ok) = a_col 810 for name in sql_columns: 811 if self.__col_def_hash.has_key(name) or joined_cols_hash.has_key(name): 812 # only include declared columns! 813 if self.__col_def_hash.has_key(name): 814 c_name,c_type,c_options = self.__col_def_hash[name] 815 if c_type == kBigString and c_options.get("compress_ok",0) and a_row[col_num]: 816 try: 817 a_col_data = zlib.decompress(a_row[col_num]) 818 except zlib.error: 819 a_col_data = a_row[col_num] 820 821 data_dict[name] = a_col_data 822 elif c_type == kInteger or c_type == kIncInteger: 823 value = a_row[col_num] 824 if not value is None: 825 data_dict[name] = int(value) 826 else: 827 data_dict[name] = None 828 else: 829 data_dict[name] = a_row[col_num] 830 831 else: 832 data_dict[name] = a_row[col_num] 833 834 col_num = col_num + 1 835 836 newrowobj = self.__defaultRowClass(self,data_dict,joined_cols = joined_cols) 837 return_rows.append(newrowobj) 838 839 840 841 return return_rows 842 843 def __deleteRow(self,a_row,cursor = None): 844 if cursor is None: 845 cursor = self.db.defaultCursor() 846 847 # build the where clause! 848 match_spec = a_row.getPKMatchSpec() 849 sql_where_list = self.__buildWhereClause (match_spec) 850 851 sql = "delete from %s where %s" % (self.__table_name, 852 string.join(sql_where_list," and ")) 853 dlog(DEV_UPDATE,sql) 854 cursor.execute(sql) 855 856 857 def __updateRowList(self,a_row_list,cursor = None): 858 if cursor is None: 859 cursor = self.db.defaultCursor() 860 861 for a_row in a_row_list: 862 update_list = a_row.changedList() 863 864 # build the set list! 
865 sql_set_list = [] 866 for a_change in update_list: 867 col_name,col_val,col_inc_val = a_change 868 c_name,c_type,c_options = self.__col_def_hash[col_name] 869 870 if c_type != kIncInteger and col_val is None: 871 sql_set_list.append("%s = NULL" % c_name) 872 elif c_type == kIncInteger and col_inc_val is None: 873 sql_set_list.append("%s = 0" % c_name) 874 else: 875 if c_type == kInteger: 876 sql_set_list.append("%s = %d" % (c_name, long(col_val))) 877 elif c_type == kIncInteger: 878 sql_set_list.append("%s = %s + %d" % (c_name,c_name,long(col_inc_val))) 879 elif c_type == kBigString and c_options.get("compress_ok",0) and self.db.compression_enabled: 880 compressed_data = zlib.compress(col_val,9) 881 if len(compressed_data) < len(col_val): 882 sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(compressed_data))) 883 else: 884 sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val))) 885 elif c_type == kReal: 886 sql_set_list.append("%s = %s" % (c_name,float(col_val))) 887 888 else: 889 sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val))) 890 891 # build the where clause! 
892 match_spec = a_row.getPKMatchSpec() 893 sql_where_list = self.__buildWhereClause (match_spec) 894 895 if sql_set_list: 896 sql = "update %s set %s where %s" % (self.__table_name, 897 string.join(sql_set_list,","), 898 string.join(sql_where_list," and ")) 899 900 dlog(DEV_UPDATE,sql) 901 try: 902 cursor.execute(sql) 903 except Exception, reason: 904 if string.find(str(reason), "Duplicate entry") != -1: 905 raise eDuplicateKey, reason 906 raise Exception, reason 907 a_row.markClean() 908 909 def __insertRow(self,a_row_obj,cursor = None,replace=0): 910 if cursor is None: 911 cursor = self.db.defaultCursor() 912 913 sql_col_list = [] 914 sql_data_list = [] 915 auto_increment_column_name = None 916 917 for a_col in self.__column_list: 918 name,type,options = a_col 919 920 try: 921 data = a_row_obj[name] 922 923 sql_col_list.append(name) 924 if data is None: 925 sql_data_list.append("NULL") 926 else: 927 if type == kInteger or type == kIncInteger: 928 sql_data_list.append("%d" % data) 929 elif type == kBigString and options.get("compress_ok",0) and self.db.compression_enabled: 930 compressed_data = zlib.compress(data,9) 931 if len(compressed_data) < len(data): 932 sql_data_list.append("'%s'" % self.db.escape(compressed_data)) 933 else: 934 sql_data_list.append("'%s'" % self.db.escape(data)) 935 elif type == kReal: 936 sql_data_list.append("%s" % data) 937 else: 938 sql_data_list.append("'%s'" % self.db.escape(data)) 939 940 except KeyError: 941 if options.has_key("autoincrement"): 942 if auto_increment_column_name: 943 raise eInternalError, "two autoincrement columns (%s,%s) in table (%s)" % (auto_increment_column_name, name,self.__table_name) 944 else: 945 auto_increment_column_name = name 946 947 if replace: 948 sql = "replace into %s (%s) values (%s)" % (self.__table_name, 949 string.join(sql_col_list,","), 950 string.join(sql_data_list,",")) 951 else: 952 sql = "insert into %s (%s) values (%s)" % (self.__table_name, 953 string.join(sql_col_list,","), 954 
string.join(sql_data_list,",")) 955 956 dlog(DEV_UPDATE,sql) 957 try: 958 cursor.execute(sql) 959 except Exception, reason: 960 # sys.stderr.write("errror in statement: " + sql + "\n") 961 log("error in statement: " + sql + "\n") 962 if string.find(str(reason), "Duplicate entry") != -1: 963 raise eDuplicateKey, reason 964 raise Exception, reason 965 966 if auto_increment_column_name: 967 if cursor.__module__ == "sqlite.main": 968 a_row_obj[auto_increment_column_name] = cursor.lastrowid 969 elif cursor.__module__ == "MySQLdb.cursors": 970 a_row_obj[auto_increment_column_name] = cursor.insert_id() 971 else: 972 # fallback to acting like mysql 973 a_row_obj[auto_increment_column_name] = cursor.insert_id() 974 975 # ---------------------------------------------------- 976 # Helper methods for Rows... 977 # ---------------------------------------------------- 978 979 980 981 ##################### 982 # r_deleteRow(a_row_obj,cursor = None) 983 # 984 # normally this is called from within the Row "delete()" method 985 # but you can call it yourself if you want 986 # 987 988 def r_deleteRow(self,a_row_obj, cursor = None): 989 curs = cursor 990 self.__deleteRow(a_row_obj, cursor = curs) 991 992 993 ##################### 994 # r_updateRow(a_row_obj,cursor = None) 995 # 996 # normally this is called from within the Row "save()" method 997 # but you can call it yourself if you want 998 # 999 1000 def r_updateRow(self,a_row_obj, cursor = None): 1001 curs = cursor 1002 self.__updateRowList([a_row_obj], cursor = curs) 1003 1004 ##################### 1005 # InsertRow(a_row_obj,cursor = None) 1006 # 1007 # normally this is called from within the Row "save()" method 1008 # but you can call it yourself if you want 1009 # 1010 1011 def r_insertRow(self,a_row_obj, cursor = None,replace=0): 1012 curs = cursor 1013 self.__insertRow(a_row_obj, cursor = curs,replace=replace) 1014 1015 1016 # ---------------------------------------------------- 1017 # Public Methods 1018 # 
---------------------------------------------------- 1019 1020 1021 1022 ##################### 1023 # deleteRow(col_match_spec) 1024 # 1025 # The col_match_spec paramaters must include all primary key columns. 1026 # 1027 # Ex: 1028 # a_row = tbl.fetchRow( ("order_id", 1) ) 1029 # a_row = tbl.fetchRow( [ ("order_id", 1), ("enterTime", now) ] ) 1030 1031 1032 def deleteRow(self,col_match_spec, where=None): 1033 n_match_spec = self._fixColMatchSpec(col_match_spec) 1034 cursor = self.db.defaultCursor() 1035 1036 # build sql where clause elements 1037 sql_where_list = self.__buildWhereClause (n_match_spec,where) 1038 if not sql_where_list: 1039 return 1040 1041 sql = "delete from %s where %s" % (self.__table_name, string.join(sql_where_list," and ")) 1042 1043 dlog(DEV_UPDATE,sql) 1044 cursor.execute(sql) 1045 1046 ##################### 1047 # fetchRow(col_match_spec) 1048 # 1049 # The col_match_spec paramaters must include all primary key columns. 1050 # 1051 # Ex: 1052 # a_row = tbl.fetchRow( ("order_id", 1) ) 1053 # a_row = tbl.fetchRow( [ ("order_id", 1), ("enterTime", now) ] ) 1054 1055 1056 def fetchRow(self, col_match_spec, cursor = None): 1057 n_match_spec = self._fixColMatchSpec(col_match_spec, should_match_unique_row = 1) 1058 1059 rows = self.__fetchRows(n_match_spec, cursor = cursor) 1060 if len(rows) == 0: 1061 raise eNoMatchingRows, "no row matches %s" % repr(n_match_spec) 1062 1063 if len(rows) > 1: 1064 raise eInternalError, "unique where clause shouldn't return > 1 row" 1065 1066 return rows[0] 1067 1068 1069 ##################### 1070 # fetchRows(col_match_spec) 1071 # 1072 # Ex: 1073 # a_row_list = tbl.fetchRows( ("order_id", 1) ) 1074 # a_row_list = tbl.fetchRows( [ ("order_id", 1), ("enterTime", now) ] ) 1075 1076 1077 def fetchRows(self, col_match_spec = None, cursor = None, 1078 where = None, order_by = None, limit_to = None, 1079 skip_to = None, join = None): 1080 n_match_spec = self._fixColMatchSpec(col_match_spec) 1081 1082 return 
self.__fetchRows(n_match_spec, 1083 cursor = cursor, 1084 where = where, 1085 order_by = order_by, 1086 limit_to = limit_to, 1087 skip_to = skip_to, 1088 join = join) 1089 1090 def fetchRowCount (self, col_match_spec = None, 1091 cursor = None, where = None): 1092 n_match_spec = self._fixColMatchSpec(col_match_spec) 1093 sql_where_list = self.__buildWhereClause (n_match_spec,where) 1094 sql = "select count(*) from %s" % self.__table_name 1095 if sql_where_list: 1096 sql = "%s where %s" % (sql,string.join(sql_where_list," and ")) 1097 if cursor is None: 1098 cursor = self.db.defaultCursor() 1099 dlog(DEV_SELECT,sql) 1100 cursor.execute(sql) 1101 try: 1102 count, = cursor.fetchone() 1103 except TypeError: 1104 count = 0 1105 return count 1106 1107 1108 ##################### 1109 # fetchAllRows() 1110 # 1111 # Ex: 1112 # a_row_list = tbl.fetchRows( ("order_id", 1) ) 1113 # a_row_list = tbl.fetchRows( [ ("order_id", 1), ("enterTime", now) ] ) 1114 1115 def fetchAllRows(self): 1116 try: 1117 return self.__fetchRows([]) 1118 except eNoMatchingRows: 1119 # else return empty list... 
1120 return self.__defaultRowListClass() 1121 1122 def newRow(self,replace=0): 1123 row = self.__defaultRowClass(self,None,create=1,replace=replace) 1124 for (cname, ctype, opts) in self.__column_list: 1125 if opts['default'] is not None and ctype is not kIncInteger: 1126 row[cname] = opts['default'] 1127 return row 1128 1129class Row: 1130 __instance_data_locked = 0 1131 def subclassinit(self): 1132 pass 1133 def __init__(self,_table,data_dict,create=0,joined_cols = None,replace=0): 1134 1135 self._inside_getattr = 0 # stop recursive __getattr__ 1136 self._table = _table 1137 self._should_insert = create or replace 1138 self._should_replace = replace 1139 self._rowInactive = None 1140 self._joinedRows = [] 1141 1142 self.__pk_match_spec = None 1143 self.__vcoldata = {} 1144 self.__inc_coldata = {} 1145 1146 self.__joined_cols_dict = {} 1147 for a_col in joined_cols or []: 1148 self.__joined_cols_dict[a_col] = 1 1149 1150 if create: 1151 self.__coldata = {} 1152 else: 1153 if type(data_dict) != type({}): 1154 raise eInternalError, "rowdict instantiate with bad data_dict" 1155 self.__coldata = data_dict 1156 self.__unpackVColumn() 1157 1158 self.markClean() 1159 1160 self.subclassinit() 1161 self.__instance_data_locked = 1 1162 1163 def joinRowData(self,another_row): 1164 self._joinedRows.append(another_row) 1165 1166 def getPKMatchSpec(self): 1167 return self.__pk_match_spec 1168 1169 def markClean(self): 1170 self.__vcolchanged = 0 1171 self.__colchanged_dict = {} 1172 1173 for key in self.__inc_coldata.keys(): 1174 self.__coldata[key] = self.__coldata.get(key, 0) + self.__inc_coldata[key] 1175 1176 self.__inc_coldata = {} 1177 1178 if not self._should_insert: 1179 # rebuild primary column match spec 1180 new_match_spec = [] 1181 for col_name in self._table.getPrimaryKeyList(): 1182 try: 1183 rdata = self[col_name] 1184 except KeyError: 1185 raise eInternalError, "must have primary key data filled in to save %s:Row(col:%s)" % (self._table.getTableName(),col_name) 
1186 1187 new_match_spec.append( (col_name, rdata) ) 1188 self.__pk_match_spec = new_match_spec 1189 1190 def __unpackVColumn(self): 1191 if self._table.hasValueColumn(): 1192 pass 1193 1194 def __packVColumn(self): 1195 if self._table.hasValueColumn(): 1196 pass 1197 1198 ## ----- utility stuff ---------------------------------- 1199 1200 def __del__(self): 1201 # check for unsaved changes 1202 changed_list = self.changedList() 1203 if len(changed_list): 1204 info = "unsaved Row for table (%s) lost, call discard() to avoid this error. Lost changes: %s\n" % (self._table.getTableName(), repr(changed_list)[:256]) 1205 if 0: 1206 raise eUnsavedObjectLost, info 1207 else: 1208 sys.stderr.write(info) 1209 1210 1211 def __repr__(self): 1212 return "Row from (%s): %s" % (self._table.getTableName(),repr(self.__coldata) + repr(self.__vcoldata)) 1213 1214 ## ---- class emulation -------------------------------- 1215 1216 def __getattr__(self,key): 1217 if self._inside_getattr: 1218 raise AttributeError, "recursively called __getattr__ (%s,%s)" % (key,self._table.getTableName()) 1219 try: 1220 self._inside_getattr = 1 1221 try: 1222 return self[key] 1223 except KeyError: 1224 if self._table.hasColumn(key) or self._table.hasVColumn(key): 1225 return None 1226 else: 1227 raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName()) 1228 finally: 1229 self._inside_getattr = 0 1230 1231 def __setattr__(self,key,val): 1232 if not self.__instance_data_locked: 1233 self.__dict__[key] = val 1234 else: 1235 my_dict = self.__dict__ 1236 if my_dict.has_key(key): 1237 my_dict[key] = val 1238 else: 1239 # try and put it into the rowdata 1240 try: 1241 self[key] = val 1242 except KeyError, reason: 1243 raise AttributeError, reason 1244 1245 1246 ## ---- dict emulation --------------------------------- 1247 1248 def __getitem__(self,key): 1249 self.checkRowActive() 1250 1251 try: 1252 c_type = self._table.columnType(key) 1253 except eNoSuchColumn, reason: 1254 # 
Ugh, this sucks, we can't determine the type for a joined 1255 # row, so we just default to kVarString and let the code below 1256 # determine if this is a joined column or not 1257 c_type = kVarString 1258 1259 if c_type == kIncInteger: 1260 c_data = self.__coldata.get(key, 0) 1261 if c_data is None: c_data = 0 1262 i_data = self.__inc_coldata.get(key, 0) 1263 if i_data is None: i_data = 0 1264 return c_data + i_data 1265 1266 try: 1267 return self.__coldata[key] 1268 except KeyError: 1269 try: 1270 return self.__vcoldata[key] 1271 except KeyError: 1272 for a_joined_row in self._joinedRows: 1273 try: 1274 return a_joined_row[key] 1275 except KeyError: 1276 pass 1277 1278 raise KeyError, "unknown column %s in %s" % (key,self) 1279 1280 def __setitem__(self,key,data): 1281 self.checkRowActive() 1282 1283 try: 1284 newdata = self._table.convertDataForColumn(data,key) 1285 except eNoSuchColumn, reason: 1286 raise KeyError, reason 1287 1288 if self._table.hasColumn(key): 1289 self.__coldata[key] = newdata 1290 self.__colchanged_dict[key] = 1 1291 elif self._table.hasVColumn(key): 1292 self.__vcoldata[key] = newdata 1293 self.__vcolchanged = 1 1294 else: 1295 for a_joined_row in self._joinedRows: 1296 try: 1297 a_joined_row[key] = data 1298 return 1299 except KeyError: 1300 pass 1301 raise KeyError, "unknown column name %s" % key 1302 1303 def __delitem__(self,key,data): 1304 self.checkRowActive() 1305 1306 if self.table.hasVColumn(key): 1307 del self.__vcoldata[key] 1308 else: 1309 for a_joined_row in self._joinedRows: 1310 try: 1311 del a_joined_row[key] 1312 return 1313 except KeyError: 1314 pass 1315 raise KeyError, "unknown column name %s" % key 1316 1317 1318 def copyFrom(self,source): 1319 for name,t,options in self._table.getColumnList(): 1320 if not options.has_key("autoincrement"): 1321 self[name] = source[name] 1322 1323 1324 # make sure that .keys(), and .items() come out in a nice order! 
1325 1326 def keys(self): 1327 self.checkRowActive() 1328 1329 key_list = [] 1330 for name,t,options in self._table.getColumnList(): 1331 key_list.append(name) 1332 for name in self.__joined_cols_dict.keys(): 1333 key_list.append(name) 1334 1335 for a_joined_row in self._joinedRows: 1336 key_list = key_list + a_joined_row.keys() 1337 1338 return key_list 1339 1340 1341 def items(self): 1342 self.checkRowActive() 1343 1344 item_list = [] 1345 for name,t,options in self._table.getColumnList(): 1346 item_list.append( (name,self[name]) ) 1347 for name in self.__joined_cols_dict.keys(): 1348 item_list.append( (name,self[name]) ) 1349 1350 for a_joined_row in self._joinedRows: 1351 item_list = item_list + a_joined_row.items() 1352 1353 1354 return item_list 1355 1356 def values(elf): 1357 self.checkRowActive() 1358 1359 value_list = self.__coldata.values() + self.__vcoldata.values() 1360 1361 for a_joined_row in self._joinedRows: 1362 value_list = value_list + a_joined_row.values() 1363 1364 return value_list 1365 1366 1367 def __len__(self): 1368 self.checkRowActive() 1369 1370 my_len = len(self.__coldata) + len(self.__vcoldata) 1371 1372 for a_joined_row in self._joinedRows: 1373 my_len = my_len + len(a_joined_row) 1374 1375 return my_len 1376 1377 def has_key(self,key): 1378 self.checkRowActive() 1379 1380 if self.__coldata.has_key(key) or self.__vcoldata.has_key(key): 1381 return 1 1382 else: 1383 1384 for a_joined_row in self._joinedRows: 1385 if a_joined_row.has_key(key): 1386 return 1 1387 return 0 1388 1389 def get(self,key,default = None): 1390 self.checkRowActive() 1391 1392 1393 1394 if self.__coldata.has_key(key): 1395 return self.__coldata[key] 1396 elif self.__vcoldata.has_key(key): 1397 return self.__vcoldata[key] 1398 else: 1399 for a_joined_row in self._joinedRows: 1400 try: 1401 return a_joined_row.get(key,default) 1402 except eNoSuchColumn: 1403 pass 1404 1405 if self._table.hasColumn(key): 1406 return default 1407 1408 raise eNoSuchColumn, "no such 
column %s" % key 1409 1410 def inc(self,key,count=1): 1411 self.checkRowActive() 1412 1413 if self._table.hasColumn(key): 1414 try: 1415 self.__inc_coldata[key] = self.__inc_coldata[key] + count 1416 except KeyError: 1417 self.__inc_coldata[key] = count 1418 1419 self.__colchanged_dict[key] = 1 1420 else: 1421 raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName()) 1422 1423 1424 ## ---------------------------------- 1425 ## real interface 1426 1427 1428 def fillDefaults(self): 1429 for field_def in self._table.fieldList(): 1430 name,type,size,options = field_def 1431 if options.has_key("default"): 1432 self[name] = options["default"] 1433 1434 ############### 1435 # changedList() 1436 # 1437 # returns a list of tuples for the columns which have changed 1438 # 1439 # changedList() -> [ ('name', 'fred'), ('age', 20) ] 1440 1441 def changedList(self): 1442 if self.__vcolchanged: 1443 self.__packVColumn() 1444 1445 changed_list = [] 1446 for a_col in self.__colchanged_dict.keys(): 1447 changed_list.append( (a_col,self.get(a_col,None),self.__inc_coldata.get(a_col,None)) ) 1448 1449 return changed_list 1450 1451 def discard(self): 1452 self.__coldata = None 1453 self.__vcoldata = None 1454 self.__colchanged_dict = {} 1455 self.__vcolchanged = 0 1456 1457 def delete(self,cursor = None): 1458 self.checkRowActive() 1459 1460 1461 fromTable = self._table 1462 curs = cursor 1463 fromTable.r_deleteRow(self,cursor=curs) 1464 self._rowInactive = "deleted" 1465 1466 def save(self,cursor = None): 1467 toTable = self._table 1468 1469 self.checkRowActive() 1470 1471 if self._should_insert: 1472 toTable.r_insertRow(self,replace=self._should_replace) 1473 self._should_insert = 0 1474 self._should_replace = 0 1475 self.markClean() # rebuild the primary key list 1476 else: 1477 curs = cursor 1478 toTable.r_updateRow(self,cursor = curs) 1479 1480 # the table will mark us clean! 
1481 # self.markClean() 1482 1483 def checkRowActive(self): 1484 if self._rowInactive: 1485 raise eInvalidData, "row is inactive: %s" % self._rowInactive 1486 1487 def databaseSizeForColumn(self,key): 1488 return self._table.databaseSizeForData_ColumnName_(self[key],key) 1489 1490 1491if __name__ == "__main__": 1492 print "run odb_test.py" 1493