#!/usr/bin/env python

import datetime
import sys
from getpass import getpass
from optparse import OptionParser

from peewee import *
from peewee import print_
from peewee import __version__ as peewee_version
from playhouse.cockroachdb import CockroachDatabase
from playhouse.reflection import *


# Template for the top of the generated module: extra imports, the database
# class name, the (escaped) database name and optional connect kwargs.
HEADER = """from peewee import *%s

database = %s('%s'%s)
"""

BASE_MODEL = """\
class BaseModel(Model):
    class Meta:
        database = database
"""

UNKNOWN_FIELD = """\
class UnknownField(object):
    def __init__(self, *_, **__): pass
"""

DATABASE_ALIASES = {
    CockroachDatabase: ['cockroach', 'cockroachdb', 'crdb'],
    MySQLDatabase: ['mysql', 'mysqldb'],
    PostgresqlDatabase: ['postgres', 'postgresql'],
    SqliteDatabase: ['sqlite', 'sqlite3'],
}

# Reverse lookup: alias string -> database class.
DATABASE_MAP = dict((value, key)
                    for key in DATABASE_ALIASES
                    for value in DATABASE_ALIASES[key])


def _escape_literal(value):
    """Escape *value* for embedding inside a single-quoted Python literal.

    Database names (e.g. Windows SQLite paths containing backslashes),
    table names and schema names are written verbatim into the generated
    source, so backslashes, single quotes and line breaks must be escaped
    or the generated module would be wrong or fail to parse. Ordinary
    names pass through unchanged.
    """
    text = '%s' % (value,)
    return (text.replace('\\', '\\\\')   # Must run first.
                .replace("'", "\\'")
                .replace('\n', '\\n')
                .replace('\r', '\\r'))


def make_introspector(database_type, database_name, **kwargs):
    """Create an ``Introspector`` for the given engine alias and database.

    :param database_type: an alias key of ``DATABASE_MAP`` (e.g. 'sqlite').
    :param database_name: name (or path) of the database to connect to.
    :param kwargs: connection options; a ``schema`` entry is removed and
        passed to ``Introspector.from_database``, the rest go to the
        database class constructor.
    Prints an error and exits with status 1 for an unknown alias.
    """
    if database_type not in DATABASE_MAP:
        err('Unrecognized database, must be one of: %s' %
            ', '.join(sorted(DATABASE_MAP)))
        sys.exit(1)

    schema = kwargs.pop('schema', None)
    DatabaseClass = DATABASE_MAP[database_type]
    db = DatabaseClass(database_name, **kwargs)
    return Introspector.from_database(db, schema=schema)


def print_models(introspector, tables=None, preserve_order=False,
                 include_views=False, ignore_unknown=False, snake_case=True):
    """Print peewee model source for the introspected database to stdout.

    :param introspector: an ``Introspector`` for the target database.
    :param tables: optional list of table names to generate; tables they
        reference through foreign keys are printed as well.
    :param preserve_order: keep columns in table order instead of sorting
        them by column name.
    :param include_views: also introspect VIEWs.
    :param ignore_unknown: emit columns of unknown type as comments instead
        of ``UnknownField`` declarations (and skip the ``UnknownField`` stub).
    :param snake_case: forwarded to ``introspector.introspect``.
    """
    database = introspector.introspect(table_names=tables,
                                       include_views=include_views,
                                       snake_case=snake_case)

    db_kwargs = introspector.get_database_kwargs()
    header = HEADER % (
        introspector.get_additional_imports(),
        introspector.get_database_class().__name__,
        _escape_literal(introspector.get_database_name()),
        ', **%s' % repr(db_kwargs) if db_kwargs else '')
    print_(header)

    if not ignore_unknown:
        print_(UNKNOWN_FIELD)

    print_(BASE_MODEL)

    def _print_table(table, seen, accum=None):
        """Print the model for *table*, after the tables it references.

        *seen* holds tables already printed or scheduled for printing;
        *accum* is the chain of tables whose printing is in progress, used
        to spot foreign-key reference cycles.
        """
        accum = accum or []
        foreign_keys = database.foreign_keys[table]
        for foreign_key in foreign_keys:
            dest = foreign_key.dest_table

            # In the event the destination table has already been pushed
            # for printing, then we have a reference cycle.
            if dest in accum and table not in accum:
                print_('# Possible reference cycle: %s' % dest)

            # If this is not a self-referential foreign key, and we have
            # not already processed the destination table, do so now.
            if dest not in seen and dest not in accum:
                seen.add(dest)
                if dest != table:
                    _print_table(dest, seen, accum + [table])

        print_('class %s(BaseModel):' % database.model_names[table])
        columns = database.columns[table].items()
        if not preserve_order:
            columns = sorted(columns)
        primary_keys = database.primary_keys[table]
        for name, column in columns:
            # A lone auto-incrementing "id" primary key is implicit in
            # peewee models, so it is not written out.
            skip = all([
                name in primary_keys,
                name == 'id',
                len(primary_keys) == 1,
                column.field_class in introspector.pk_classes])
            if skip:
                continue
            if column.primary_key and len(primary_keys) > 1:
                # If we have a CompositeKey, then we do not want to explicitly
                # mark the columns as being primary keys.
                column.primary_key = False

            is_unknown = column.field_class is UnknownField
            if is_unknown and ignore_unknown:
                disp = '%s - %s' % (column.name, column.raw_column_type or '?')
                print_('    # %s' % disp)
            else:
                print_('    %s' % column.get_field())

        print_('')
        print_('    class Meta:')
        print_("        table_name = '%s'" % _escape_literal(table))
        multi_column_indexes = database.multi_column_indexes(table)
        if multi_column_indexes:
            print_('        indexes = (')
            for fields, unique in sorted(multi_column_indexes):
                print_('            ((%s), %s),' % (
                    ', '.join("'%s'" % _escape_literal(field)
                              for field in fields),
                    unique,
                ))
            print_('        )')

        if introspector.schema:
            print_("        schema = '%s'" % _escape_literal(introspector.schema))
        if len(primary_keys) > 1:
            pk_field_names = sorted([
                field.name for col, field in columns
                if col in primary_keys])
            pk_list = ', '.join("'%s'" % _escape_literal(pk)
                                for pk in pk_field_names)
            print_('        primary_key = CompositeKey(%s)' % pk_list)
        elif not primary_keys:
            print_('        primary_key = False')
        print_('')

        seen.add(table)

    seen = set()
    for table in sorted(database.model_names.keys()):
        if table not in seen:
            if not tables or table in tables:
                _print_table(table, seen)


def print_header(cmd_line, introspector):
    """Print a comment banner with the command line, time and versions."""
    timestamp = datetime.datetime.now()
    print_('# Code generated by:')
    print_('# python -m pwiz %s' % cmd_line)
    print_('# Date: %s' % timestamp.strftime('%B %d, %Y %I:%M%p'))
    print_('# Database: %s' % introspector.get_database_name())
    print_('# Peewee version: %s' % peewee_version)
    print_('')


def err(msg):
    """Write *msg* to stderr wrapped in ANSI red color codes."""
    sys.stderr.write('\033[91m%s\033[0m\n' % msg)
    sys.stderr.flush()


def get_option_parser():
    """Build the ``OptionParser`` describing pwiz's command-line options."""
    parser = OptionParser(usage='usage: %prog [options] database_name')
    ao = parser.add_option
    ao('-H', '--host', dest='host')
    ao('-p', '--port', dest='port', type='int')
    ao('-u', '--user', dest='user')
    ao('-P', '--password', dest='password', action='store_true')
    engines = sorted(DATABASE_MAP)
    ao('-e', '--engine', dest='engine', default='postgresql', choices=engines,
       help=('Database type, e.g. sqlite, mysql, postgresql or cockroachdb. '
             'Default is "postgresql".'))
    ao('-s', '--schema', dest='schema')
    ao('-t', '--tables', dest='tables',
       help=('Only generate the specified tables. Multiple table names should '
             'be separated by commas.'))
    ao('-v', '--views', dest='views', action='store_true',
       help='Generate model classes for VIEWs in addition to tables.')
    ao('-i', '--info', dest='info', action='store_true',
       help=('Add database information and other metadata to top of the '
             'generated file.'))
    ao('-o', '--preserve-order', action='store_true', dest='preserve_order',
       help='Model definition column ordering matches source table.')
    ao('-I', '--ignore-unknown', action='store_true', dest='ignore_unknown',
       help='Ignore fields whose type cannot be determined.')
    ao('-L', '--legacy-naming', action='store_true', dest='legacy_naming',
       help='Use legacy table- and column-name generation.')
    return parser


def get_connect_kwargs(options):
    """Collect truthy connection options; prompt for a password if -P given."""
    ops = ('host', 'port', 'user', 'schema')
    kwargs = dict((o, getattr(options, o)) for o in ops if getattr(options, o))
    if options.password:
        kwargs['password'] = getpass()
    return kwargs


if __name__ == '__main__':
    raw_argv = sys.argv

    parser = get_option_parser()
    options, args = parser.parse_args()

    if len(args) < 1:
        err('Missing required parameter "database"')
        parser.print_help()
        sys.exit(1)

    connect = get_connect_kwargs(options)
    database = args[-1]

    tables = None
    if options.tables:
        tables = [table.strip() for table in options.tables.split(',')
                  if table.strip()]

    introspector = make_introspector(options.engine, database, **connect)
    if options.info:
        cmd_line = ' '.join(raw_argv[1:])
        print_header(cmd_line, introspector)

    print_models(introspector, tables, options.preserve_order, options.views,
                 options.ignore_unknown, not options.legacy_naming)