1#!/usr/bin/env python
2#
3# odb.py
4#
5# Object Database Api
6#
7# Written by David Jeske <jeske@neotonic.com>, 2001/07.
8# Inspired by eGroups' sqldb.py originally written by Scott Hassan circa 1998.
9#
10# Copyright (C) 2001, by David Jeske and Neotonic
11#
12# Goals:
13#       - a simple object-like interface to database data
14#       - database independent (someday)
15#       - relational-style "rigid schema definition"
16#       - object style easy-access
17#
18# Example:
19#
#  import odb, MySQLdb
#
#  # define table
#  class AgentsTable(odb.Table):
#    def _defineRows(self):
#      self.d_addColumn("agent_id",odb.kInteger,None,primarykey = 1,autoincrement = 1)
#      self.d_addColumn("login",odb.kVarString,200,notnull=1)
#      self.d_addColumn("ticket_count",odb.kIncInteger,None)
28#
29#  if __name__ == "__main__":
30#    # open database
31#    ndb = MySQLdb.connect(host = 'localhost',
32#                          user='username',
33#                          passwd = 'password',
34#                          db='testdb')
#    db = odb.Database(ndb)   # in practice a database-specific subclass
36#    tbl = AgentsTable(db,"agents")
37#
38#    # create row
39#    agent_row = tbl.newRow()
40#    agent_row.login = "foo"
41#    agent_row.save()
42#
43#    # fetch row (must use primary key)
44#    try:
45#      get_row = tbl.fetchRow( ('agent_id', agent_row.agent_id) )
46#    except odb.eNoMatchingRows:
47#      print "this is bad, we should have found the row"
48#
49#    # fetch rows (can return empty list)
50#    list_rows = tbl.fetchRows( ('login', "foo") )
51#
52
53import string
54import sys, zlib
55from log import *
56
57import handle_error
58
# Error identifiers.  These are plain strings used as old-style "string
# exceptions": the code raises them as `raise eNoSuchColumn, "message"`
# and callers catch them as `except odb.eNoSuchColumn:`.
eNoSuchColumn         = "odb.eNoSuchColumn"
eNonUniqueMatchSpec   = "odb.eNonUniqueMatchSpec"
eNoMatchingRows       = "odb.eNoMatchingRows"
eInternalError        = "odb.eInternalError"
eInvalidMatchSpec     = "odb.eInvalidMatchSpec"
eInvalidData          = "odb.eInvalidData"
eUnsavedObjectLost    = "odb.eUnsavedObjectLost"
eDuplicateKey         = "odb.eDuplicateKey"

#####################################
# COLUMN TYPES
#
# Passed as the `ctype` argument of Table.d_addColumn().  The trailing
# comment says what d_addColumn's `size` argument means for the type;
# "-" means size is not used when generating SQL.
#
# typename                           # size data means:
kInteger       = "kInteger"          # -
kFixedString   = "kFixedString"      # size
kVarString     = "kVarString"        # maxsize
kBigString     = "kBigString"        # -
kIncInteger    = "kIncInteger"       # -  (updated as "col = col + n")
kDateTime      = "kDateTime"         # -
kTimeStamp     = "kTimeStamp"        # -
kReal          = "kReal"             # -


# module debug flag (not referenced in this part of the file)
DEBUG = 0
84
85##############
86# Database
87#
88# this will ultimately turn into a mostly abstract base class for
89# the DB adaptors for different database types....
90#
91
class Database:
    """Connection wrapper that owns the Table objects of one database.

    This will ultimately turn into a mostly abstract base class for the
    DB adaptors of the different database types: the methods that raise
    "Unimplemented Error" (escape, listTables, listIndices, listFieldsDict,
    alterTableToMatch) must be supplied by a database-specific subclass.

    Tables registered through addTable() are also reachable as attributes,
    e.g. ``db.agents`` after ``db.addTable("agents", ...)``.
    """

    def __init__(self, db, debug=0):
        """Wrap an already-open connection object `db`.

        This base class only calls db.cursor() and db.close().
        """
        self._tables = {}
        self.db = db
        self._cursor = None
        self.compression_enabled = 0
        self.debug = debug
        # exception class caught around driver calls by Table; None here.
        # NOTE(review): presumably assigned by driver-specific subclasses.
        self.SQLError = None

        self.__defaultRowClass = self.defaultRowClass()
        self.__defaultRowListClass = self.defaultRowListClass()

    def defaultCursor(self):
        """Return the shared cursor, creating it on first use."""
        if self._cursor is None:
            self._cursor = self.db.cursor()
        return self._cursor

    def escape(self, str):
        """Escape `str` for use inside a single-quoted SQL literal.

        Must be implemented by a subclass.
        """
        raise "Unimplemented Error"

    # row / row-list classes used by Tables that do not specify their own
    def getDefaultRowClass(self): return self.__defaultRowClass
    def setDefaultRowClass(self, clss): self.__defaultRowClass = clss
    def getDefaultRowListClass(self): return self.__defaultRowListClass
    def setDefaultRowListClass(self, clss): self.__defaultRowListClass = clss

    def defaultRowClass(self):
        """Return the initial default row class (Row)."""
        return Row

    def defaultRowListClass(self):
        """Return the initial default row-list class (the builtin list)."""
        return list

    def addTable(self, attrname, tblname, tblclass,
                 rowClass = None, check = 0, create = 0, rowListClass = None):
        """Instantiate tblclass for SQL table tblname and register it.

        The table becomes reachable as attribute `attrname` of this
        Database.  Returns the new table object.
        """
        tbl = tblclass(self, tblname, rowClass=rowClass, check=check,
                       create=create, rowListClass=rowListClass)
        self._tables[attrname] = tbl
        return tbl

    def close(self):
        """Detach every table and close the underlying connection.

        The cached default cursor is dropped too, so it is never reused
        against a closed connection.
        """
        for tbl in self._tables.values():
            tbl.db = None
        self._tables = {}
        self._cursor = None
        if self.db is not None:
            self.db.close()
            self.db = None

    def __getattr__(self, key):
        # Only called when normal lookup fails: resolve registered tables.
        if key == "_tables":
            # __init__ never ran; avoid infinite recursion below
            raise AttributeError("odb.Database: not initialized properly, self._tables does not exist")

        try:
            return self._tables[key]
        except KeyError:
            raise AttributeError("odb.Database: unknown attribute %s" % (key))

    def beginTransaction(self, cursor=None):
        """Execute "begin" on cursor (default: the shared cursor)."""
        if cursor is None:
            cursor = self.defaultCursor()
        dlog(DEV_UPDATE,"begin")
        cursor.execute("begin")

    def commitTransaction(self, cursor=None):
        """Execute "commit" on cursor (default: the shared cursor)."""
        if cursor is None:
            cursor = self.defaultCursor()
        dlog(DEV_UPDATE,"commit")
        cursor.execute("commit")

    def rollbackTransaction(self, cursor=None):
        """Execute "rollback" on cursor (default: the shared cursor)."""
        if cursor is None:
            cursor = self.defaultCursor()
        dlog(DEV_UPDATE,"rollback")
        cursor.execute("rollback")

    ##
    ## schema creation code
    ##

    def createTables(self):
        """Create every registered table that is missing from the database.

        Existing tables are only compared with Table.checkTable(), which
        prints warnings about mismatched columns; they are not altered.
        """
        tables = self.listTables()

        for tbl in self._tables.values():
            tblname = tbl.getTableName()

            if tblname not in tables:
                print("table %s does not exist" % tblname)
                tbl.createTable()
            else:
                tbl.checkTable()

    def createIndices(self):
        """Create every index registered via Table.addIndex() that is missing."""
        indices = self.listIndices()

        for tbl in self._tables.values():
            for indexName, (columns, unique) in tbl.getIndices().items():
                if indexName in indices:
                    continue
                tbl.createIndex(columns, indexName=indexName, unique=unique)

    def synchronizeSchema(self):
        """Ask alterTableToMatch() to bring each registered table up to date."""
        for tbl in self._tables.values():
            self.alterTableToMatch(tbl)

    def listTables(self, cursor=None):
        """Return the names of the tables in the database (subclass)."""
        raise "Unimplemented Error"

    def listIndices(self, cursor=None):
        """Return the names of the existing indices (subclass).

        Used by createIndices().
        """
        raise "Unimplemented Error"

    def alterTableToMatch(self, tbl):
        """Alter tbl's live table to match its definition (subclass).

        Used by synchronizeSchema().
        """
        raise "Unimplemented Error"

    def listFieldsDict(self, table_name, cursor=None):
        """Return a dict keyed by the column names of table_name (subclass).

        The values are driver-specific column descriptions.
        """
        raise "Unimplemented Error"

    def listFields(self, table_name, cursor=None):
        """Return the column names of table_name."""
        columns = self.listFieldsDict(table_name, cursor=cursor)
        return columns.keys()
210
211##########################################
212# Table
213#
214
215
216class Table:
217    def subclassinit(self):
218        pass
219    def __init__(self,database,table_name,
220                 rowClass = None, check = 0, create = 0, rowListClass = None):
221        self.db = database
222        self.__table_name = table_name
223        if rowClass:
224            self.__defaultRowClass = rowClass
225        else:
226            self.__defaultRowClass = database.getDefaultRowClass()
227
228        if rowListClass:
229            self.__defaultRowListClass = rowListClass
230        else:
231            self.__defaultRowListClass = database.getDefaultRowListClass()
232
233        # get this stuff ready!
234
235        self.__column_list = []
236        self.__vcolumn_list = []
237        self.__columns_locked = 0
238        self.__has_value_column = 0
239
240        self.__indices = {}
241
242        # this will be used during init...
243        self.__col_def_hash = None
244        self.__vcol_def_hash = None
245        self.__primary_key_list = None
246        self.__relations_by_table = {}
247
248        # ask the subclass to def his rows
249        self._defineRows()
250
251        # get ready to run!
252        self.__lockColumnsAndInit()
253
254        self.subclassinit()
255
256        if create:
257            self.createTable()
258
259        if check:
260            self.checkTable()
261
262    def _colTypeToSQLType(self, colname, coltype, options):
263
264      if coltype == kInteger:
265        coltype = "integer"
266      elif coltype == kFixedString:
267        sz = options.get('size', None)
268        if sz is None: coltype = 'char'
269        else:  coltype = "char(%s)" % sz
270      elif coltype == kVarString:
271        sz = options.get('size', None)
272        if sz is None: coltype = 'varchar'
273        else:  coltype = "varchar(%s)" % sz
274      elif coltype == kBigString:
275        coltype = "text"
276      elif coltype == kIncInteger:
277        coltype = "integer"
278      elif coltype == kDateTime:
279        coltype = "datetime"
280      elif coltype == kTimeStamp:
281        coltype = "timestamp"
282      elif coltype == kReal:
283        coltype = "real"
284
285      coldef = "%s %s" % (colname, coltype)
286
287      if options.get('notnull', 0): coldef = coldef + " NOT NULL"
288      if options.get('autoincrement', 0): coldef = coldef + " AUTO_INCREMENT"
289      if options.get('unique', 0): coldef = coldef + " UNIQUE"
290#      if options.get('primarykey', 0): coldef = coldef + " primary key"
291      if options.get('default', None) is not None: coldef = coldef + " DEFAULT %s" % options.get('default')
292
293      return coldef
294
295    def getTableName(self):  return self.__table_name
296    def setTableName(self, tablename):  self.__table_name = tablename
297
298    def getIndices(self): return self.__indices
299
300    def _createTableSQL(self):
301      defs = []
302      for colname, coltype, options in self.__column_list:
303        defs.append(self._colTypeToSQLType(colname, coltype, options))
304
305      defs = string.join(defs, ", ")
306
307      primarykeys = self.getPrimaryKeyList()
308      primarykey_str = ""
309      if primarykeys:
310	primarykey_str = ", PRIMARY KEY (" + string.join(primarykeys, ",") + ")"
311
312      sql = "create table %s (%s %s)" % (self.__table_name, defs, primarykey_str)
313      return sql
314
315    def createTable(self, cursor=None):
316      if cursor is None: cursor = self.db.defaultCursor()
317      sql = self._createTableSQL()
318      print "CREATING TABLE:", sql
319      cursor.execute(sql)
320
321    def dropTable(self, cursor=None):
322      if cursor is None: cursor = self.db.defaultCursor()
323      try:
324        cursor.execute("drop table %s" % self.__table_name)   # clean out the table
325      except self.SQLError, reason:
326        pass
327
328    def renameTable(self, newTableName, cursor=None):
329      if cursor is None: cursor = self.db.defaultCursor()
330      try:
331        cursor.execute("rename table %s to %s" % (self.__table_name, newTableName))
332      except sel.SQLError, reason:
333        pass
334
335      self.setTableName(newTableName)
336
337    def getTableColumnsFromDB(self):
338      return self.db.listFieldsDict(self.__table_name)
339
340    def checkTable(self, warnflag=1):
341      invalidDBCols = {}
342      invalidAppCols = {}
343
344      dbcolumns = self.getTableColumnsFromDB()
345      for coldef in self.__column_list:
346        colname = coldef[0]
347
348        dbcoldef = dbcolumns.get(colname, None)
349        if dbcoldef is None:
350          invalidAppCols[colname] = 1
351
352      for colname, row in dbcolumns.items():
353        coldef = self.__col_def_hash.get(colname, None)
354        if coldef is None:
355          invalidDBCols[colname] = 1
356
357      if warnflag == 1:
358        if invalidDBCols:
359          print "----- WARNING ------------------------------------------"
360          print "  There are columns defined in the database schema that do"
361          print "  not match the application's schema."
362          print "  columns:", invalidDBCols.keys()
363          print "--------------------------------------------------------"
364
365        if invalidAppCols:
366          print "----- WARNING ------------------------------------------"
367          print "  There are new columns defined in the application schema"
368          print "  that do not match the database's schema."
369          print "  columns:", invalidAppCols.keys()
370          print "--------------------------------------------------------"
371
372      return invalidAppCols, invalidDBCols
373
374
    def alterTableToMatch(self):
        """Alter the live table so it matches this definition.

        Not implemented in this base class; a subclass must override it.
        """
        raise "Unimplemented Error!"
377
378    def addIndex(self, columns, indexName=None, unique=0):
379      if indexName is None:
380        indexName = self.getTableName() + "_index_" + string.join(columns, "_")
381
382      self.__indices[indexName] = (columns, unique)
383
384    def createIndex(self, columns, indexName=None, unique=0, cursor=None):
385      if cursor is None: cursor = self.db.defaultCursor()
386      cols = string.join(columns, ",")
387
388      if indexName is None:
389        indexName = self.getTableName() + "_index_" + string.join(columns, "_")
390
391      uniquesql = ""
392      if unique:
393        uniquesql = " unique"
394      sql = "create %s index %s on %s (%s)" % (uniquesql, indexName, self.getTableName(), cols)
395      warn("creating index", sql)
396      cursor.execute(sql)
397
398
399    ## Column Definition
400
401    def getColumnDef(self,column_name):
402        try:
403            return self.__col_def_hash[column_name]
404        except KeyError:
405            try:
406                return self.__vcol_def_hash[column_name]
407            except KeyError:
408                raise eNoSuchColumn, "no column (%s) on table %s" % (column_name,self.__table_name)
409
410    def getColumnList(self):
411      return self.__column_list + self.__vcolumn_list
412    def getAppColumnList(self): return self.__column_list
413
414    def databaseSizeForData_ColumnName_(self,data,col_name):
415        try:
416            col_def = self.__col_def_hash[col_name]
417        except KeyError:
418            try:
419                col_def = self.__vcol_def_hash[col_name]
420            except KeyError:
421                raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name)
422
423        c_name,c_type,c_options = col_def
424
425        if c_type == kBigString:
426            if c_options.get("compress_ok",0) and self.db.compression_enabled:
427                z_size = len(zlib.compress(data,9))
428                r_size = len(data)
429                if z_size < r_size:
430                    return z_size
431                else:
432                    return r_size
433            else:
434                return len(data)
435        else:
436            # really simplistic database size computation:
437            try:
438                a = data[0]
439                return len(data)
440            except:
441                return 4
442
443
444    def columnType(self, col_name):
445        try:
446            col_def = self.__col_def_hash[col_name]
447        except KeyError:
448            try:
449                col_def = self.__vcol_def_hash[col_name]
450            except KeyError:
451                raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name)
452
453        c_name,c_type,c_options = col_def
454        return c_type
455
456    def convertDataForColumn(self,data,col_name):
457        try:
458            col_def = self.__col_def_hash[col_name]
459        except KeyError:
460            try:
461                col_def = self.__vcol_def_hash[col_name]
462            except KeyError:
463                raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name)
464
465        c_name,c_type,c_options = col_def
466
467        if c_type == kIncInteger:
468            raise eInvalidData, "invalid operation for column (%s:%s) on table (%s)" % (col_name,c_type,self.__table_name)
469
470        if c_type == kInteger:
471            try:
472                if data is None: data = 0
473                else: return long(data)
474            except (ValueError,TypeError):
475                raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data),col_name,c_type,self.__table_name)
476        elif c_type == kReal:
477            try:
478                if data is None: data = 0.0
479                else: return float(data)
480            except (ValueError,TypeError):
481                raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data), col_name,c_type,self.__table_name)
482
483        else:
484            if type(data) == type(long(0)):
485                return "%d" % data
486            else:
487                return str(data)
488
489    def getPrimaryKeyList(self):
490        return self.__primary_key_list
491
492    def hasValueColumn(self):
493        return self.__has_value_column
494
495    def hasColumn(self,name):
496        return self.__col_def_hash.has_key(name)
497    def hasVColumn(self,name):
498        return self.__vcol_def_hash.has_key(name)
499
500
    def _defineRows(self):
        """Declare this table's columns; every subclass must override this.

        Called once from __init__, before the definitions are locked.  Use
        d_addColumn(), d_addValueColumn() and d_addVColumn() here.
        """
        raise "can't instantiate base odb.Table type, make a subclass and override _defineRows()"
503
504    def __lockColumnsAndInit(self):
505        # add a 'odb_value column' before we lockdown the table def
506        if self.__has_value_column:
507            self.d_addColumn("odb_value",kBigText,default='')
508
509        self.__columns_locked = 1
510        # walk column list and make lookup hashes, primary_key_list, etc..
511
512        primary_key_list = []
513        col_def_hash = {}
514        for a_col in self.__column_list:
515            name,type,options = a_col
516            col_def_hash[name] = a_col
517            if options.has_key('primarykey'):
518                primary_key_list.append(name)
519
520        self.__col_def_hash = col_def_hash
521        self.__primary_key_list = primary_key_list
522
523        # setup the value columns!
524
525        if (not self.__has_value_column) and (len(self.__vcolumn_list) > 0):
526            raise "can't define vcolumns on table without ValueColumn, call d_addValueColumn() in your _defineRows()"
527
528        vcol_def_hash = {}
529        for a_col in self.__vcolumn_list:
530            name,type,size_data,options = a_col
531            vcol_def_hash[name] = a_col
532
533        self.__vcol_def_hash = vcol_def_hash
534
535
536    def __checkColumnLock(self):
537        if self.__columns_locked:
538            raise "can't change column definitions outside of subclass' _defineRows() method!"
539
540    # table definition methods, these are only available while inside the
541    # subclass's _defineRows method
542    #
543    # Ex:
544    #
545    # import odb
546    # class MyTable(odb.Table):
547    #   def _defineRows(self):
548    #     self.d_addColumn("id",kInteger,primarykey = 1,autoincrement = 1)
549    #     self.d_addColumn("name",kVarString,120)
    #     self.d_addColumn("type",kInteger,
    #                      enum_values = { 0 : "alive", 1 : "dead" })
552
    def d_addColumn(self,col_name,ctype,size=None,primarykey = 0,
                    notnull = 0,indexed=0,
                    default=None,unique=0,autoincrement=0,safeupdate=0,
                    enum_values = None,
		    no_export = 0,
                    relations=None,compress_ok=0,int_date=0):
        """Declare a regular column; only callable from inside _defineRows().

        col_name      -- column name
        ctype         -- one of the k* column-type constants
        size          -- size / maxsize for kFixedString / kVarString
        primarykey    -- column is (part of) the primary key
        notnull, autoincrement, unique -- emitted as SQL column constraints;
                         a unique column also counts as a unique match
                         spec in _fixColMatchSpec()
        indexed       -- also register a single-column index via addIndex()
        default       -- default value; always stored in the options dict,
                         even when None
        safeupdate, no_export -- recorded in the options dict; their
                         consumers are outside this part of the file
        enum_values   -- dict of stored value -> name; must be one-to-one,
                         and the inverse is stored as options['inv_enum_values']
        relations     -- list of (foreign_table, foreign_column) pairs used
                         by the `join` argument of row fetches; at most one
                         relation per foreign table
        compress_ok   -- allow zlib compression (kBigString only)
        int_date      -- flag for kInteger columns only (presumably a date
                         stored as an integer; only the type check is
                         visible here)

        Other options are recorded only when truthy.  Raises eInvalidData
        for invalid flag/type combinations.  addIndex() and the relation
        bookkeeping run before some of those checks, so they may already
        have taken effect when eInvalidData is raised.
        """

        self.__checkColumnLock()

        options = {}
        # default is recorded unconditionally (may be None)
        options['default']       = default
        if primarykey:
            options['primarykey']    = primarykey
        if unique:
            options['unique']        = unique
        if indexed:
            options['indexed']       = indexed
            # registered immediately, before the validity checks below
            self.addIndex((col_name,))
        if safeupdate:
            options['safeupdate']    = safeupdate
        if autoincrement:
            options['autoincrement'] = autoincrement
        if notnull:
            options['notnull']       = notnull
        if size:
            options['size']          = size
        if no_export:
            options['no_export']     = no_export
        if int_date:
            if ctype != kInteger:
                raise eInvalidData, "can't flag columns int_date unless they are kInteger"
            else:
                options['int_date'] = int_date

        if enum_values:
            options['enum_values']   = enum_values
            # build the reverse (name -> value) map, rejecting duplicate names
            inv_enum_values = {}
            for k,v in enum_values.items():
                if inv_enum_values.has_key(v):
                    raise eInvalidData, "enum_values paramater must be a 1 to 1 mapping for Table(%s)" % self.__table_name
                else:
                    inv_enum_values[v] = k
            options['inv_enum_values'] = inv_enum_values
        if relations:
            options['relations']      = relations
            # keyed by foreign table; looked up when building joins
            for a_relation in relations:
                table, foreign_column_name = a_relation
                if self.__relations_by_table.has_key(table):
                    raise eInvalidData, "multiple relations for the same foreign table are not yet supported"
                self.__relations_by_table[table] = (col_name,foreign_column_name)
        if compress_ok:
            if ctype == kBigString:
                options['compress_ok'] = 1
            else:
                raise eInvalidData, "only kBigString fields can be compress_ok=1"

        self.__column_list.append( (col_name,ctype,options) )
610
611    def d_addValueColumn(self):
612        self.__checkColumnLock()
613        self.__has_value_column = 1
614
615    def d_addVColumn(self,col_name,type,size=None,default=None):
616        self.__checkColumnLock()
617
618        if (not self.__has_value_column):
619            raise "can't define VColumns on table without ValueColumn, call d_addValueColumn() first"
620
621        options = {}
622        if default:
623            options['default'] = default
624        if size:
625            options['size']    = size
626
627        self.__vcolumn_list.append( (col_name,type,options) )
628
629    #####################
    # _fixColMatchSpec(col_match_spec,should_match_unique_row = 0)
    #
    # normalize col_match_spec to a list of (column, value) pairs, raising an
    # error if it contains invalid columns, or (in the case of
    # should_match_unique_row) if it does not fully specify a unique row.
635    #
636    # NOTE: we don't currently support where clauses with value column fields!
637    #
638
639    def _fixColMatchSpec(self,col_match_spec, should_match_unique_row = 0):
640        if type(col_match_spec) == type([]):
641            if type(col_match_spec[0]) != type((0,)):
642                raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)"
643        elif type(col_match_spec) == type((0,)):
644            col_match_spec = [ col_match_spec ]
645        elif type(col_match_spec) == type(None):
646            if should_match_unique_row:
647                raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec
648            else:
649                return None
650        else:
651            raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)"
652
653        if should_match_unique_row:
654            unique_column_lists = []
655
656            # first the primary key list
657            my_primary_key_list = []
658            for a_key in self.__primary_key_list:
659                my_primary_key_list.append(a_key)
660
661            # then other unique keys
662            for a_col in self.__column_list:
663                col_name,a_type,options = a_col
664                if options.has_key('unique'):
665                    unique_column_lists.append( (col_name, [col_name]) )
666
667            unique_column_lists.append( ('primary_key', my_primary_key_list) )
668
669
670        new_col_match_spec = []
671        for a_col in col_match_spec:
672            name,val = a_col
673            # newname = string.lower(name)
674            #  what is this doing?? - jeske
675            newname = name
676            if not self.__col_def_hash.has_key(newname):
677                raise eNoSuchColumn, "no such column in match spec: '%s'" % newname
678
679            new_col_match_spec.append( (newname,val) )
680
681            if should_match_unique_row:
682                for name,a_list in unique_column_lists:
683                    try:
684                        a_list.remove(newname)
685                    except ValueError:
686                        # it's okay if they specify too many columns!
687                        pass
688
689        if should_match_unique_row:
690            for name,a_list in unique_column_lists:
691                if len(a_list) == 0:
692                    # we matched at least one unique colum spec!
693                    # log("using unique column (%s) for query %s" % (name,col_match_spec))
694                    return new_col_match_spec
695
696            raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec
697
698        return new_col_match_spec
699
700    def __buildWhereClause (self, col_match_spec,other_clauses = None):
701        sql_where_list = []
702
703        if not col_match_spec is None:
704            for m_col in col_match_spec:
705                m_col_name,m_col_val = m_col
706                c_name,c_type,c_options = self.__col_def_hash[m_col_name]
707                if c_type in (kIncInteger, kInteger):
708                    try:
709                        m_col_val_long = long(m_col_val)
710                    except ValueError:
711                        raise ValueError, "invalid literal for long(%s) in table %s" % (repr(m_col_val),self.__table_name)
712
713                    sql_where_list.append("%s = %d" % (c_name, m_col_val_long))
714                elif c_type == kReal:
715                    try:
716                        m_col_val_float = float(m_col_val)
717                    except ValueError:
718                        raise ValueError, "invalid literal for float(%s) is table %s" % (repr(m_col_val), self.__table_name)
719                    sql_where_list.append("%s = %s" % (c_name, m_col_val_float))
720                else:
721                    sql_where_list.append("%s = '%s'" % (c_name, self.db.escape(m_col_val)))
722
723        if other_clauses is None:
724            pass
725        elif type(other_clauses) == type(""):
726            sql_where_list = sql_where_list + [other_clauses]
727        elif type(other_clauses) == type([]):
728            sql_where_list = sql_where_list + other_clauses
729        else:
730            raise eInvalidData, "unknown type of extra where clause: %s" % repr(other_clauses)
731
732        return sql_where_list
733
734    def __fetchRows(self,col_match_spec,cursor = None, where = None, order_by = None, limit_to = None,
735                    skip_to = None, join = None):
736        if cursor is None:
737            cursor = self.db.defaultCursor()
738
739        # build column list
740        sql_columns = []
741        for name,t,options in self.__column_list:
742            sql_columns.append(name)
743
744        # build join information
745
746        joined_cols = []
747        joined_cols_hash = {}
748        join_clauses = []
749        if not join is None:
750            for a_table,retrieve_foreign_cols in join:
751                try:
752                    my_col,foreign_col = self.__relations_by_table[a_table]
753                    for a_col in retrieve_foreign_cols:
754                        full_col_name = "%s.%s" % (my_col,a_col)
755                        joined_cols_hash[full_col_name] = 1
756                        joined_cols.append(full_col_name)
757                        sql_columns.append( full_col_name )
758
759                    join_clauses.append(" left join %s as %s on %s=%s " % (a_table,my_col,my_col,foreign_col))
760
761                except KeyError:
762                    eInvalidJoinSpec, "can't find table %s in defined relations for %s" % (a_table,self.__table_name)
763
764        # start buildling SQL
765        sql = "select %s from %s" % (string.join(sql_columns,","),
766                                     self.__table_name)
767
768        # add join clause
769        if join_clauses:
770            sql = sql + string.join(join_clauses," ")
771
772        # add where clause elements
773        sql_where_list = self.__buildWhereClause (col_match_spec,where)
774        if sql_where_list:
775            sql = sql + " where %s" % (string.join(sql_where_list," and "))
776
777        # add order by clause
778        if order_by:
779            sql = sql + " order by %s " % string.join(order_by,",")
780
781        # add limit
782        if not limit_to is None:
783            if not skip_to is None:
784#                log("limit,skip = %s,%s" % (limit_to,skip_to))
785                if self.db.db.__module__ == "sqlite.main":
786                    sql = sql + " limit %s offset %s " % (limit_to,skip_to)
787                else:
788                    sql = sql + " limit %s, %s" % (skip_to,limit_to)
789            else:
790                sql = sql + " limit %s" % limit_to
791        else:
792            if not skip_to is None:
793                raise eInvalidData, "can't specify skip_to without limit_to in MySQL"
794
795        dlog(DEV_SELECT,sql)
796        cursor.execute(sql)
797
798        # create defaultRowListClass instance...
799        return_rows = self.__defaultRowListClass()
800
801        # should do fetchmany!
802        all_rows = cursor.fetchall()
803        for a_row in all_rows:
804            data_dict = {}
805
806            col_num = 0
807
808            #            for a_col in cursor.description:
809            #                (name,type_code,display_size,internal_size,precision,scale,null_ok) = a_col
810            for name in sql_columns:
811                if self.__col_def_hash.has_key(name) or joined_cols_hash.has_key(name):
812                    # only include declared columns!
813                    if self.__col_def_hash.has_key(name):
814                        c_name,c_type,c_options = self.__col_def_hash[name]
815                        if c_type == kBigString and c_options.get("compress_ok",0) and a_row[col_num]:
816                            try:
817                                a_col_data = zlib.decompress(a_row[col_num])
818                            except zlib.error:
819                                a_col_data = a_row[col_num]
820
821                            data_dict[name] = a_col_data
822                        elif c_type == kInteger or c_type == kIncInteger:
823                            value = a_row[col_num]
824                            if not value is None:
825                                data_dict[name] = int(value)
826                            else:
827                                data_dict[name] = None
828                        else:
829                            data_dict[name] = a_row[col_num]
830
831                    else:
832                        data_dict[name] = a_row[col_num]
833
834                    col_num = col_num + 1
835
836	    newrowobj = self.__defaultRowClass(self,data_dict,joined_cols = joined_cols)
837	    return_rows.append(newrowobj)
838
839
840
841        return return_rows
842
843    def __deleteRow(self,a_row,cursor = None):
844        if cursor is None:
845            cursor = self.db.defaultCursor()
846
847        # build the where clause!
848        match_spec = a_row.getPKMatchSpec()
849        sql_where_list = self.__buildWhereClause (match_spec)
850
851        sql = "delete from %s where %s" % (self.__table_name,
852                                           string.join(sql_where_list," and "))
853        dlog(DEV_UPDATE,sql)
854        cursor.execute(sql)
855
856
857    def __updateRowList(self,a_row_list,cursor = None):
858        if cursor is None:
859            cursor = self.db.defaultCursor()
860
861        for a_row in a_row_list:
862            update_list = a_row.changedList()
863
864            # build the set list!
865            sql_set_list = []
866            for a_change in update_list:
867                col_name,col_val,col_inc_val = a_change
868                c_name,c_type,c_options = self.__col_def_hash[col_name]
869
870                if c_type != kIncInteger and col_val is None:
871                    sql_set_list.append("%s = NULL" % c_name)
872                elif c_type == kIncInteger and col_inc_val is None:
873                    sql_set_list.append("%s = 0" % c_name)
874                else:
875                    if c_type == kInteger:
876                        sql_set_list.append("%s = %d" % (c_name, long(col_val)))
877                    elif c_type == kIncInteger:
878                        sql_set_list.append("%s = %s + %d" % (c_name,c_name,long(col_inc_val)))
879                    elif c_type == kBigString and c_options.get("compress_ok",0) and self.db.compression_enabled:
880                        compressed_data = zlib.compress(col_val,9)
881                        if len(compressed_data) < len(col_val):
882                            sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(compressed_data)))
883                        else:
884                            sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val)))
885                    elif c_type == kReal:
886                        sql_set_list.append("%s = %s" % (c_name,float(col_val)))
887
888                    else:
889                        sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val)))
890
891            # build the where clause!
892            match_spec = a_row.getPKMatchSpec()
893            sql_where_list = self.__buildWhereClause (match_spec)
894
895            if sql_set_list:
896                sql = "update %s set %s where %s" % (self.__table_name,
897                                                 string.join(sql_set_list,","),
898                                                 string.join(sql_where_list," and "))
899
900                dlog(DEV_UPDATE,sql)
901                try:
902                    cursor.execute(sql)
903                except Exception, reason:
904                    if string.find(str(reason), "Duplicate entry") != -1:
905                        raise eDuplicateKey, reason
906                    raise Exception, reason
907                a_row.markClean()
908
909    def __insertRow(self,a_row_obj,cursor = None,replace=0):
910        if cursor is None:
911            cursor = self.db.defaultCursor()
912
913        sql_col_list = []
914        sql_data_list = []
915        auto_increment_column_name = None
916
917        for a_col in self.__column_list:
918            name,type,options = a_col
919
920            try:
921                data = a_row_obj[name]
922
923                sql_col_list.append(name)
924                if data is None:
925                    sql_data_list.append("NULL")
926                else:
927                    if type == kInteger or type == kIncInteger:
928                        sql_data_list.append("%d" % data)
929                    elif type == kBigString and options.get("compress_ok",0) and self.db.compression_enabled:
930                        compressed_data = zlib.compress(data,9)
931                        if len(compressed_data) < len(data):
932                            sql_data_list.append("'%s'" % self.db.escape(compressed_data))
933                        else:
934                            sql_data_list.append("'%s'" % self.db.escape(data))
935                    elif type == kReal:
936                        sql_data_list.append("%s" % data)
937                    else:
938                        sql_data_list.append("'%s'" % self.db.escape(data))
939
940            except KeyError:
941                if options.has_key("autoincrement"):
942                    if auto_increment_column_name:
943                        raise eInternalError, "two autoincrement columns (%s,%s) in table (%s)" % (auto_increment_column_name, name,self.__table_name)
944                    else:
945                        auto_increment_column_name = name
946
947        if replace:
948            sql = "replace into %s (%s) values (%s)" % (self.__table_name,
949                                                   string.join(sql_col_list,","),
950                                                   string.join(sql_data_list,","))
951        else:
952            sql = "insert into %s (%s) values (%s)" % (self.__table_name,
953                                                   string.join(sql_col_list,","),
954                                                   string.join(sql_data_list,","))
955
956        dlog(DEV_UPDATE,sql)
957        try:
958          cursor.execute(sql)
959        except Exception, reason:
960          # sys.stderr.write("errror in statement: " + sql + "\n")
961          log("error in statement: " + sql + "\n")
962          if string.find(str(reason), "Duplicate entry") != -1:
963            raise eDuplicateKey, reason
964          raise Exception, reason
965
966        if auto_increment_column_name:
967            if cursor.__module__ == "sqlite.main":
968                a_row_obj[auto_increment_column_name] = cursor.lastrowid
969            elif cursor.__module__ == "MySQLdb.cursors":
970                a_row_obj[auto_increment_column_name] = cursor.insert_id()
971            else:
972                # fallback to acting like mysql
973                a_row_obj[auto_increment_column_name] = cursor.insert_id()
974
975    # ----------------------------------------------------
976    #   Helper methods for Rows...
977    # ----------------------------------------------------
978
979
980
981    #####################
982    # r_deleteRow(a_row_obj,cursor = None)
983    #
984    # normally this is called from within the Row "delete()" method
985    # but you can call it yourself if you want
986    #
987
988    def r_deleteRow(self,a_row_obj, cursor = None):
989        curs = cursor
990        self.__deleteRow(a_row_obj, cursor = curs)
991
992
993    #####################
994    # r_updateRow(a_row_obj,cursor = None)
995    #
996    # normally this is called from within the Row "save()" method
997    # but you can call it yourself if you want
998    #
999
1000    def r_updateRow(self,a_row_obj, cursor = None):
1001        curs = cursor
1002        self.__updateRowList([a_row_obj], cursor = curs)
1003
1004    #####################
    # r_insertRow(a_row_obj,cursor = None,replace = 0)
1006    #
1007    # normally this is called from within the Row "save()" method
1008    # but you can call it yourself if you want
1009    #
1010
1011    def r_insertRow(self,a_row_obj, cursor = None,replace=0):
1012        curs = cursor
1013        self.__insertRow(a_row_obj, cursor = curs,replace=replace)
1014
1015
1016    # ----------------------------------------------------
1017    #   Public Methods
1018    # ----------------------------------------------------
1019
1020
1021
1022    #####################
    # deleteRow(col_match_spec, where = None)
    #
    # The col_match_spec parameters must include all primary key columns.
    #
    # Ex:
    #    tbl.deleteRow( ("order_id", 1) )
    #    tbl.deleteRow( [ ("order_id", 1), ("enterTime", now) ] )
1030
1031
1032    def deleteRow(self,col_match_spec, where=None):
1033        n_match_spec = self._fixColMatchSpec(col_match_spec)
1034        cursor = self.db.defaultCursor()
1035
1036        # build sql where clause elements
1037        sql_where_list = self.__buildWhereClause (n_match_spec,where)
1038        if not sql_where_list:
1039            return
1040
1041        sql = "delete from %s where %s" % (self.__table_name, string.join(sql_where_list," and "))
1042
1043        dlog(DEV_UPDATE,sql)
1044        cursor.execute(sql)
1045
1046    #####################
1047    # fetchRow(col_match_spec)
1048    #
    # The col_match_spec parameters must include all primary key columns.
1050    #
1051    # Ex:
1052    #    a_row = tbl.fetchRow( ("order_id", 1) )
1053    #    a_row = tbl.fetchRow( [ ("order_id", 1), ("enterTime", now) ] )
1054
1055
1056    def fetchRow(self, col_match_spec, cursor = None):
1057        n_match_spec = self._fixColMatchSpec(col_match_spec, should_match_unique_row = 1)
1058
1059        rows = self.__fetchRows(n_match_spec, cursor = cursor)
1060        if len(rows) == 0:
1061            raise eNoMatchingRows, "no row matches %s" % repr(n_match_spec)
1062
1063        if len(rows) > 1:
1064            raise eInternalError, "unique where clause shouldn't return > 1 row"
1065
1066        return rows[0]
1067
1068
1069    #####################
1070    # fetchRows(col_match_spec)
1071    #
1072    # Ex:
1073    #    a_row_list = tbl.fetchRows( ("order_id", 1) )
1074    #    a_row_list = tbl.fetchRows( [ ("order_id", 1), ("enterTime", now) ] )
1075
1076
1077    def fetchRows(self, col_match_spec = None, cursor = None,
1078		  where = None, order_by = None, limit_to = None,
1079		  skip_to = None, join = None):
1080        n_match_spec = self._fixColMatchSpec(col_match_spec)
1081
1082        return self.__fetchRows(n_match_spec,
1083                                cursor = cursor,
1084                                where = where,
1085                                order_by = order_by,
1086                                limit_to = limit_to,
1087                                skip_to = skip_to,
1088                                join = join)
1089
1090    def fetchRowCount (self, col_match_spec = None,
1091		       cursor = None, where = None):
1092        n_match_spec = self._fixColMatchSpec(col_match_spec)
1093        sql_where_list = self.__buildWhereClause (n_match_spec,where)
1094	sql = "select count(*) from %s" % self.__table_name
1095        if sql_where_list:
1096            sql = "%s where %s" % (sql,string.join(sql_where_list," and "))
1097        if cursor is None:
1098          cursor = self.db.defaultCursor()
1099        dlog(DEV_SELECT,sql)
1100        cursor.execute(sql)
1101        try:
1102            count, = cursor.fetchone()
1103        except TypeError:
1104            count = 0
1105        return count
1106
1107
1108    #####################
    # fetchAllRows()
    #
    # Returns every row in the table.
    #
    # Ex:
    #    a_row_list = tbl.fetchAllRows()
1114
1115    def fetchAllRows(self):
1116        try:
1117            return self.__fetchRows([])
1118        except eNoMatchingRows:
1119            # else return empty list...
1120            return self.__defaultRowListClass()
1121
1122    def newRow(self,replace=0):
1123        row = self.__defaultRowClass(self,None,create=1,replace=replace)
1124        for (cname, ctype, opts) in self.__column_list:
1125            if opts['default'] is not None and ctype is not kIncInteger:
1126                row[cname] = opts['default']
1127        return row
1128
1129class Row:
1130    __instance_data_locked  = 0
1131    def subclassinit(self):
1132        pass
1133    def __init__(self,_table,data_dict,create=0,joined_cols = None,replace=0):
1134
1135        self._inside_getattr = 0  # stop recursive __getattr__
1136        self._table = _table
1137        self._should_insert = create or replace
1138        self._should_replace = replace
1139        self._rowInactive = None
1140        self._joinedRows = []
1141
1142        self.__pk_match_spec = None
1143        self.__vcoldata = {}
1144        self.__inc_coldata = {}
1145
1146        self.__joined_cols_dict = {}
1147        for a_col in joined_cols or []:
1148            self.__joined_cols_dict[a_col] = 1
1149
1150        if create:
1151            self.__coldata = {}
1152        else:
1153            if type(data_dict) != type({}):
1154                raise eInternalError, "rowdict instantiate with bad data_dict"
1155            self.__coldata = data_dict
1156            self.__unpackVColumn()
1157
1158        self.markClean()
1159
1160        self.subclassinit()
1161        self.__instance_data_locked = 1
1162
1163    def joinRowData(self,another_row):
1164        self._joinedRows.append(another_row)
1165
1166    def getPKMatchSpec(self):
1167        return self.__pk_match_spec
1168
1169    def markClean(self):
1170        self.__vcolchanged = 0
1171        self.__colchanged_dict = {}
1172
1173        for key in self.__inc_coldata.keys():
1174            self.__coldata[key] = self.__coldata.get(key, 0) + self.__inc_coldata[key]
1175
1176        self.__inc_coldata = {}
1177
1178        if not self._should_insert:
1179            # rebuild primary column match spec
1180            new_match_spec = []
1181            for col_name in self._table.getPrimaryKeyList():
1182                try:
1183                    rdata = self[col_name]
1184                except KeyError:
1185                    raise eInternalError, "must have primary key data filled in to save %s:Row(col:%s)" % (self._table.getTableName(),col_name)
1186
1187                new_match_spec.append( (col_name, rdata) )
1188            self.__pk_match_spec = new_match_spec
1189
1190    def __unpackVColumn(self):
1191        if self._table.hasValueColumn():
1192            pass
1193
1194    def __packVColumn(self):
1195        if self._table.hasValueColumn():
1196            pass
1197
1198    ## ----- utility stuff ----------------------------------
1199
1200    def __del__(self):
1201        # check for unsaved changes
1202        changed_list = self.changedList()
1203        if len(changed_list):
1204            info = "unsaved Row for table (%s) lost, call discard() to avoid this error. Lost changes: %s\n" % (self._table.getTableName(), repr(changed_list)[:256])
1205            if 0:
1206                raise eUnsavedObjectLost, info
1207            else:
1208                sys.stderr.write(info)
1209
1210
1211    def __repr__(self):
1212        return "Row from (%s): %s" % (self._table.getTableName(),repr(self.__coldata) + repr(self.__vcoldata))
1213
1214    ## ---- class emulation --------------------------------
1215
1216    def __getattr__(self,key):
1217        if self._inside_getattr:
1218          raise AttributeError, "recursively called __getattr__ (%s,%s)" % (key,self._table.getTableName())
1219        try:
1220            self._inside_getattr = 1
1221            try:
1222                return self[key]
1223            except KeyError:
1224                if self._table.hasColumn(key) or self._table.hasVColumn(key):
1225                    return None
1226                else:
1227                    raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName())
1228        finally:
1229            self._inside_getattr = 0
1230
1231    def __setattr__(self,key,val):
1232        if not self.__instance_data_locked:
1233            self.__dict__[key] = val
1234        else:
1235            my_dict = self.__dict__
1236            if my_dict.has_key(key):
1237                my_dict[key] = val
1238            else:
1239                # try and put it into the rowdata
1240                try:
1241                    self[key] = val
1242                except KeyError, reason:
1243                    raise AttributeError, reason
1244
1245
1246    ## ---- dict emulation ---------------------------------
1247
1248    def __getitem__(self,key):
1249        self.checkRowActive()
1250
1251        try:
1252            c_type = self._table.columnType(key)
1253        except eNoSuchColumn, reason:
1254            # Ugh, this sucks, we can't determine the type for a joined
1255            # row, so we just default to kVarString and let the code below
1256            # determine if this is a joined column or not
1257            c_type = kVarString
1258
1259        if c_type == kIncInteger:
1260            c_data = self.__coldata.get(key, 0)
1261            if c_data is None: c_data = 0
1262            i_data = self.__inc_coldata.get(key, 0)
1263            if i_data is None: i_data = 0
1264            return c_data + i_data
1265
1266        try:
1267            return self.__coldata[key]
1268        except KeyError:
1269            try:
1270                return self.__vcoldata[key]
1271            except KeyError:
1272                for a_joined_row in self._joinedRows:
1273                    try:
1274                        return a_joined_row[key]
1275                    except KeyError:
1276                        pass
1277
1278                raise KeyError, "unknown column %s in %s" % (key,self)
1279
1280    def __setitem__(self,key,data):
1281        self.checkRowActive()
1282
1283        try:
1284            newdata = self._table.convertDataForColumn(data,key)
1285        except eNoSuchColumn, reason:
1286            raise KeyError, reason
1287
1288        if self._table.hasColumn(key):
1289            self.__coldata[key] = newdata
1290            self.__colchanged_dict[key] = 1
1291        elif self._table.hasVColumn(key):
1292            self.__vcoldata[key] = newdata
1293            self.__vcolchanged = 1
1294        else:
1295            for a_joined_row in self._joinedRows:
1296                try:
1297                    a_joined_row[key] = data
1298                    return
1299                except KeyError:
1300                    pass
1301            raise KeyError, "unknown column name %s" % key
1302
1303    def __delitem__(self,key,data):
1304        self.checkRowActive()
1305
1306        if self.table.hasVColumn(key):
1307            del self.__vcoldata[key]
1308        else:
1309            for a_joined_row in self._joinedRows:
1310                try:
1311                    del a_joined_row[key]
1312                    return
1313                except KeyError:
1314                    pass
1315            raise KeyError, "unknown column name %s" % key
1316
1317
1318    def copyFrom(self,source):
1319        for name,t,options in self._table.getColumnList():
1320            if not options.has_key("autoincrement"):
1321                self[name] = source[name]
1322
1323
1324    # make sure that .keys(), and .items() come out in a nice order!
1325
1326    def keys(self):
1327        self.checkRowActive()
1328
1329        key_list = []
1330        for name,t,options in self._table.getColumnList():
1331            key_list.append(name)
1332        for name in self.__joined_cols_dict.keys():
1333            key_list.append(name)
1334
1335        for a_joined_row in self._joinedRows:
1336            key_list = key_list + a_joined_row.keys()
1337
1338        return key_list
1339
1340
1341    def items(self):
1342        self.checkRowActive()
1343
1344        item_list = []
1345        for name,t,options in self._table.getColumnList():
1346            item_list.append( (name,self[name]) )
1347        for name in self.__joined_cols_dict.keys():
1348            item_list.append( (name,self[name]) )
1349
1350        for a_joined_row in self._joinedRows:
1351            item_list = item_list + a_joined_row.items()
1352
1353
1354        return item_list
1355
1356    def values(elf):
1357        self.checkRowActive()
1358
1359        value_list = self.__coldata.values() + self.__vcoldata.values()
1360
1361        for a_joined_row in self._joinedRows:
1362            value_list = value_list + a_joined_row.values()
1363
1364        return value_list
1365
1366
1367    def __len__(self):
1368        self.checkRowActive()
1369
1370        my_len = len(self.__coldata) + len(self.__vcoldata)
1371
1372        for a_joined_row in self._joinedRows:
1373            my_len = my_len + len(a_joined_row)
1374
1375        return my_len
1376
1377    def has_key(self,key):
1378        self.checkRowActive()
1379
1380        if self.__coldata.has_key(key) or self.__vcoldata.has_key(key):
1381            return 1
1382        else:
1383
1384            for a_joined_row in self._joinedRows:
1385                if a_joined_row.has_key(key):
1386                    return 1
1387            return 0
1388
1389    def get(self,key,default = None):
1390        self.checkRowActive()
1391
1392
1393
1394        if self.__coldata.has_key(key):
1395            return self.__coldata[key]
1396        elif self.__vcoldata.has_key(key):
1397            return self.__vcoldata[key]
1398        else:
1399            for a_joined_row in self._joinedRows:
1400                try:
1401                    return a_joined_row.get(key,default)
1402                except eNoSuchColumn:
1403                    pass
1404
1405            if self._table.hasColumn(key):
1406                return default
1407
1408            raise eNoSuchColumn, "no such column %s" % key
1409
1410    def inc(self,key,count=1):
1411        self.checkRowActive()
1412
1413        if self._table.hasColumn(key):
1414            try:
1415                self.__inc_coldata[key] = self.__inc_coldata[key] + count
1416            except KeyError:
1417                self.__inc_coldata[key] = count
1418
1419            self.__colchanged_dict[key] = 1
1420        else:
1421            raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName())
1422
1423
1424    ## ----------------------------------
1425    ## real interface
1426
1427
1428    def fillDefaults(self):
1429        for field_def in self._table.fieldList():
1430            name,type,size,options = field_def
1431            if options.has_key("default"):
1432                self[name] = options["default"]
1433
1434    ###############
1435    # changedList()
1436    #
1437    # returns a list of tuples for the columns which have changed
1438    #
1439    #   changedList() -> [ ('name', 'fred'), ('age', 20) ]
1440
1441    def changedList(self):
1442        if self.__vcolchanged:
1443            self.__packVColumn()
1444
1445        changed_list = []
1446        for a_col in self.__colchanged_dict.keys():
1447            changed_list.append( (a_col,self.get(a_col,None),self.__inc_coldata.get(a_col,None)) )
1448
1449        return changed_list
1450
1451    def discard(self):
1452        self.__coldata = None
1453        self.__vcoldata = None
1454        self.__colchanged_dict = {}
1455        self.__vcolchanged = 0
1456
1457    def delete(self,cursor = None):
1458        self.checkRowActive()
1459
1460
1461        fromTable = self._table
1462        curs = cursor
1463        fromTable.r_deleteRow(self,cursor=curs)
1464        self._rowInactive = "deleted"
1465
1466    def save(self,cursor = None):
1467        toTable = self._table
1468
1469        self.checkRowActive()
1470
1471        if self._should_insert:
1472            toTable.r_insertRow(self,replace=self._should_replace)
1473            self._should_insert = 0
1474            self._should_replace = 0
1475            self.markClean()  # rebuild the primary key list
1476        else:
1477            curs = cursor
1478            toTable.r_updateRow(self,cursor = curs)
1479
1480        # the table will mark us clean!
1481        # self.markClean()
1482
1483    def checkRowActive(self):
1484        if self._rowInactive:
1485            raise eInvalidData, "row is inactive: %s" % self._rowInactive
1486
1487    def databaseSizeForColumn(self,key):
1488        return self._table.databaseSizeForData_ColumnName_(self[key],key)
1489
1490
if __name__ == "__main__":
    # This module has no self-test; point the user at the test script.
    sys.stdout.write("run odb_test.py" + "\n")
1493