1#!/usr/bin/env python
2
3import datetime
4import sys
5from getpass import getpass
6from optparse import OptionParser
7
8from peewee import *
9from peewee import print_
10from peewee import __version__ as peewee_version
11from playhouse.cockroachdb import CockroachDatabase
12from playhouse.reflection import *
13
14
# Preamble of the generated module. The four %s placeholders are filled by
# print_models() with, in order: additional import lines, the database class
# name, the database name, and an optional ", **{...}" keyword-argument suffix.
HEADER = """from peewee import *%s

database = %s('%s'%s)
"""

# Base class emitted once; every generated model class derives from it.
BASE_MODEL = """\
class BaseModel(Model):
    class Meta:
        database = database
"""

# Stand-in class emitted into the generated module unless unknown fields are
# being ignored; it accepts and discards any constructor arguments.
UNKNOWN_FIELD = """\
class UnknownField(object):
    def __init__(self, *_, **__): pass
"""
30
# Engine names accepted on the command line, grouped by the database class
# that each of them selects.
DATABASE_ALIASES = {
    CockroachDatabase: ['cockroach', 'cockroachdb', 'crdb'],
    MySQLDatabase: ['mysql', 'mysqldb'],
    PostgresqlDatabase: ['postgres', 'postgresql'],
    SqliteDatabase: ['sqlite', 'sqlite3'],
}

# Inverse of DATABASE_ALIASES: maps each alias string to its database class.
DATABASE_MAP = {
    alias: database_class
    for database_class, aliases in DATABASE_ALIASES.items()
    for alias in aliases
}
41
def make_introspector(database_type, database_name, **kwargs):
    """Create an Introspector for the requested database.

    :param database_type: engine alias; must be a key of DATABASE_MAP.
    :param database_name: passed as the first argument to the database class.
    :param kwargs: forwarded to the database class, except for ``schema``
        which is removed and handed to ``Introspector.from_database``.
    :returns: the result of ``Introspector.from_database``.

    An unrecognized ``database_type`` is reported through ``err`` and the
    process exits with status 1.
    """
    if database_type not in DATABASE_MAP:
        known_engines = ', '.join(DATABASE_MAP.keys())
        err('Unrecognized database, must be one of: %s' % known_engines)
        sys.exit(1)

    # "schema" is consumed by the introspector, not by the database class,
    # so strip it out before the remaining kwargs are forwarded.
    schema_name = kwargs.pop('schema', None)
    connection = DATABASE_MAP[database_type](database_name, **kwargs)
    return Introspector.from_database(connection, schema=schema_name)
52
def print_models(introspector, tables=None, preserve_order=False,
                 include_views=False, ignore_unknown=False, snake_case=True):
    """Introspect the database and print model definitions via ``print_``.

    Output consists of the HEADER preamble, the UnknownField stub (unless
    ``ignore_unknown`` is set), the BaseModel class, and then one model class
    per table.

    :param introspector: object providing ``introspect()`` plus the database
        metadata accessors used to fill in HEADER.
    :param tables: optional collection of table names; passed to
        ``introspect()`` as ``table_names`` and also used to decide which
        tables are printed from the top-level loop.
    :param preserve_order: when false, columns are printed sorted by name;
        when true, they are printed in the order ``database.columns`` yields.
    :param include_views: forwarded to ``introspect()``.
    :param ignore_unknown: when true, the UnknownField stub is not emitted and
        columns whose field class is UnknownField are written as comments.
    :param snake_case: forwarded to ``introspect()``.
    """
    database = introspector.introspect(table_names=tables,
                                       include_views=include_views,
                                       snake_case=snake_case)

    db_kwargs = introspector.get_database_kwargs()
    header = HEADER % (
        introspector.get_additional_imports(),
        introspector.get_database_class().__name__,
        introspector.get_database_name(),
        ', **%s' % repr(db_kwargs) if db_kwargs else '')
    print_(header)

    if not ignore_unknown:
        print_(UNKNOWN_FIELD)

    print_(BASE_MODEL)

    def _print_table(table, seen, accum=None):
        """Print the model for ``table``, preceded by its FK destinations.

        ``seen`` is the shared set of tables already printed or scheduled for
        printing; ``accum`` is the list of tables on the current recursion
        path and is used to detect reference cycles.
        """
        accum = accum or []
        foreign_keys = database.foreign_keys[table]
        for foreign_key in foreign_keys:
            dest = foreign_key.dest_table

            # In the event the destination table has already been pushed
            # for printing, then we have a reference cycle.
            if dest in accum and table not in accum:
                print_('# Possible reference cycle: %s' % dest)

            # If this is not a self-referential foreign key, and we have
            # not already processed the destination table, do so now.
            if dest not in seen and dest not in accum:
                seen.add(dest)
                if dest != table:
                    # A fresh list is passed so sibling branches do not share
                    # path state.
                    _print_table(dest, seen, accum + [table])

        print_('class %s(BaseModel):' % database.model_names[table])
        columns = database.columns[table].items()
        if not preserve_order:
            columns = sorted(columns)
        primary_keys = database.primary_keys[table]
        for name, column in columns:
            # Omit a column only when it is the sole primary key, is named
            # "id", and its field class is one of introspector.pk_classes.
            # NOTE(review): presumably because peewee supplies such an "id"
            # field implicitly -- confirm.
            skip = all([
                name in primary_keys,
                name == 'id',
                len(primary_keys) == 1,
                column.field_class in introspector.pk_classes])
            if skip:
                continue
            if column.primary_key and len(primary_keys) > 1:
                # If we have a CompositeKey, then we do not want to explicitly
                # mark the columns as being primary keys.
                column.primary_key = False

            is_unknown = column.field_class is UnknownField
            if is_unknown and ignore_unknown:
                # Unknown columns are written as a comment rather than a
                # field definition.
                disp = '%s - %s' % (column.name, column.raw_column_type or '?')
                print_('    # %s' % disp)
            else:
                print_('    %s' % column.get_field())

        print_('')
        print_('    class Meta:')
        print_('        table_name = \'%s\'' % table)
        multi_column_indexes = database.multi_column_indexes(table)
        if multi_column_indexes:
            print_('        indexes = (')
            for fields, unique in sorted(multi_column_indexes):
                print_('            ((%s), %s),' % (
                    ', '.join("'%s'" % field for field in fields),
                    unique,
                ))
            print_('        )')

        if introspector.schema:
            print_('        schema = \'%s\'' % introspector.schema)
        if len(primary_keys) > 1:
            # Composite keys are declared on Meta; the field names are listed
            # in sorted order.
            pk_field_names = sorted([
                field.name for col, field in columns
                if col in primary_keys])
            pk_list = ', '.join("'%s'" % pk for pk in pk_field_names)
            print_('        primary_key = CompositeKey(%s)' % pk_list)
        elif not primary_keys:
            print_('        primary_key = False')
        print_('')

        seen.add(table)

    # Walk tables in sorted name order; tables already emitted as a foreign-key
    # dependency of an earlier table are skipped.
    seen = set()
    for table in sorted(database.model_names.keys()):
        if table not in seen:
            if not tables or table in tables:
                _print_table(table, seen)
147
def print_header(cmd_line, introspector):
    """Print a comment banner describing how the output was produced.

    :param cmd_line: command-line arguments, echoed after "python -m pwiz".
    :param introspector: supplies the database name via
        ``get_database_name()``.
    """
    generated_at = datetime.datetime.now().strftime('%B %d, %Y %I:%M%p')
    print_('# Code generated by:')
    print_('# python -m pwiz %s' % cmd_line)
    print_('# Date: %s' % generated_at)
    database_name = introspector.get_database_name()
    print_('# Database: %s' % database_name)
    print_('# Peewee version: %s' % peewee_version)
    print_('')
156
157
def err(msg):
    """Write ``msg`` to stderr wrapped in ANSI red escape codes, then flush."""
    stream = sys.stderr
    colored = '\033[91m%s\033[0m\n' % msg
    stream.write(colored)
    stream.flush()
161
def get_option_parser():
    """Build the OptionParser describing the command-line interface.

    The parser expects the database name as a positional argument and accepts
    connection options (host, port, user, password prompt, schema), the engine
    selection (restricted to the aliases in DATABASE_MAP, defaulting to
    "postgresql") and flags controlling what is generated.

    :returns: a configured ``optparse.OptionParser``.
    """
    parser = OptionParser(usage='usage: %prog [options] database_name')
    ao = parser.add_option
    ao('-H', '--host', dest='host', help='Database server host.')
    ao('-p', '--port', dest='port', type='int', help='Database server port.')
    ao('-u', '--user', dest='user', help='Database user name.')
    # -P is a flag, not a value option: the password itself is read
    # interactively later instead of being passed on the command line.
    ao('-P', '--password', dest='password', action='store_true',
       help='Prompt for the database password.')
    engines = sorted(DATABASE_MAP)
    ao('-e', '--engine', dest='engine', default='postgresql', choices=engines,
       help=('Database type, e.g. sqlite, mysql, postgresql or cockroachdb. '
             'Default is "postgresql".'))
    ao('-s', '--schema', dest='schema', help='Database schema name.')
    ao('-t', '--tables', dest='tables',
       help=('Only generate the specified tables. Multiple table names should '
             'be separated by commas.'))
    ao('-v', '--views', dest='views', action='store_true',
       help='Generate model classes for VIEWs in addition to tables.')
    ao('-i', '--info', dest='info', action='store_true',
       help=('Add database information and other metadata to top of the '
             'generated file.'))
    ao('-o', '--preserve-order', action='store_true', dest='preserve_order',
       help='Model definition column ordering matches source table.')
    ao('-I', '--ignore-unknown', action='store_true', dest='ignore_unknown',
       help='Ignore fields whose type cannot be determined.')
    ao('-L', '--legacy-naming', action='store_true', dest='legacy_naming',
       help='Use legacy table- and column-name generation.')
    return parser
189
def get_connect_kwargs(options):
    """Collect connection keyword arguments from parsed options.

    Only the ``host``, ``port``, ``user`` and ``schema`` options that hold a
    truthy value are included. When ``options.password`` is truthy the
    password is read with ``getpass()`` and stored under ``'password'``.

    :param options: object exposing the attributes named above.
    :returns: dict of keyword arguments.
    """
    kwargs = {}
    for option_name in ('host', 'port', 'user', 'schema'):
        value = getattr(options, option_name)
        if value:
            kwargs[option_name] = value
    if options.password:
        kwargs['password'] = getpass()
    return kwargs
196
197
if __name__ == '__main__':
    # Keep a reference to argv so the invocation can be echoed in the
    # generated header when --info is given.
    raw_argv = sys.argv

    parser = get_option_parser()
    options, args = parser.parse_args()

    # The database name is the one required positional argument.
    if not args:
        err('Missing required parameter "database"')
        parser.print_help()
        sys.exit(1)

    connect = get_connect_kwargs(options)
    database = args[-1]

    # --tables is a comma-separated list; blank entries are discarded.
    tables = None
    if options.tables:
        stripped_names = (name.strip() for name in options.tables.split(','))
        tables = [name for name in stripped_names if name]

    introspector = make_introspector(options.engine, database, **connect)
    if options.info:
        print_header(' '.join(raw_argv[1:]), introspector)

    print_models(
        introspector,
        tables=tables,
        preserve_order=options.preserve_order,
        include_views=options.views,
        ignore_unknown=options.ignore_unknown,
        snake_case=not options.legacy_naming)
224