1import math
2import sys
3
4from flask import abort
5from flask import render_template
6from flask import request
7from peewee import Database
8from peewee import DoesNotExist
9from peewee import Model
10from peewee import Proxy
11from peewee import SelectQuery
12from playhouse.db_url import connect as db_url_connect
13
14
class PaginatedQuery(object):
    """Paginate a Peewee ``SelectQuery`` (or a model's full ``select()``).

    The current page is taken from the ``page`` argument when given,
    otherwise from the request query-string parameter named ``page_var``.
    """

    def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
                 check_bounds=False):
        self.paginate_by = paginate_by
        self.page_var = page_var
        self.page = page or None  # 0 / '' are treated as "not given".
        self.check_bounds = check_bounds

        if isinstance(query_or_model, SelectQuery):
            self.query = query_or_model
            self.model = self.query.model
        else:
            self.model = query_or_model
            self.query = self.model.select()

    @staticmethod
    def _parse_page(value):
        """Convert a raw query-string value into a page number >= 1.

        Missing, empty, signed, padded or otherwise non-numeric values fall
        back to page 1, so a malformed URL never produces a server error.
        """
        if not value or not value.isdigit():
            return 1
        try:
            return max(1, int(value))
        except ValueError:
            # isdigit() also accepts characters such as superscripts that
            # int() rejects, and int() refuses extremely long digit strings
            # on newer Pythons.
            return 1

    def get_page(self):
        """Return the current (1-based) page number."""
        if self.page is not None:
            return self.page
        return self._parse_page(request.args.get(self.page_var))

    def get_page_count(self):
        """Return the total number of pages (0 for an empty query).

        The count query runs once; the result is cached on the instance.
        """
        if not hasattr(self, '_page_count'):
            self._page_count = int(math.ceil(
                float(self.query.count()) / self.paginate_by))
        return self._page_count

    def get_object_list(self):
        """Return the query restricted to the current page.

        With ``check_bounds`` enabled, a page beyond the last one aborts
        with 404. Page 1 is always in bounds, so an empty result set
        yields an empty page rather than a 404.
        """
        page = self.get_page()
        if self.check_bounds and page > max(self.get_page_count(), 1):
            abort(404)
        return self.query.paginate(page, self.paginate_by)
49
50
def get_object_or_404(query_or_model, *query):
    """Return the single row matched by the query, or abort with HTTP 404.

    :param query_or_model: a ``SelectQuery``, or a model class (in which
        case ``.select()`` is called on it).
    :param query: optional filter expressions passed to ``.where()``.
    :returns: the result of ``.get()`` on the (filtered) query.
    """
    if not isinstance(query_or_model, SelectQuery):
        query_or_model = query_or_model.select()
    # Only narrow the query when filters were supplied: Peewee 3's where()
    # reduces its arguments, so an empty call on a query without an
    # existing WHERE clause raises TypeError.
    if query:
        query_or_model = query_or_model.where(*query)
    try:
        return query_or_model.get()
    except DoesNotExist:
        abort(404)
58
def object_list(template_name, query, context_variable='object_list',
                paginate_by=20, page_var='page', page=None, check_bounds=True,
                **kwargs):
    """Render ``template_name`` with one page of ``query``'s results.

    The page's rows are exposed to the template under ``context_variable``,
    along with ``pagination`` (the ``PaginatedQuery``) and ``page`` (the
    current page number). Extra keyword arguments are passed through as
    additional template context.
    """
    pagination = PaginatedQuery(query, paginate_by=paginate_by,
                                page_var=page_var, page=page,
                                check_bounds=check_bounds)
    # Fetch the rows before rendering; with check_bounds this is where an
    # out-of-range page aborts.
    kwargs[context_variable] = pagination.get_object_list()
    current_page = pagination.get_page()
    return render_template(template_name, pagination=pagination,
                           page=current_page, **kwargs)
74
75def get_current_url():
76    if not request.query_string:
77        return request.path
78    return '%s?%s' % (request.path, request.query_string)
79
80def get_next_url(default='/'):
81    if request.args.get('next'):
82        return request.args['next']
83    elif request.form.get('next'):
84        return request.form['next']
85    return default
86
class FlaskDB(object):
    """Integrate a Peewee database with a Flask application.

    The database can be supplied to the constructor (a ``Database``
    instance, a connection URL, or a config ``dict``). When omitted, it is
    read from the app's ``DATABASE`` or ``DATABASE_URL`` config key.
    Attaching an app registers hooks that connect before each request and
    close the connection on request teardown. ``Model`` provides a base
    model class bound to the database.
    """

    def __init__(self, app=None, database=None, model_class=Model):
        self.database = None  # Reference to actual Peewee database instance.
        self.base_model_class = model_class
        self._app = app
        self._db = database  # dict, url, Database, or None (default).
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind to ``app``: load the configured database and register hooks.

        A database passed to the constructor takes precedence over the
        app's ``DATABASE`` config key, which takes precedence over
        ``DATABASE_URL``.

        :raises ValueError: when no database source is available.
        """
        self._app = app

        if self._db is not None:
            initial_db = self._db
        elif 'DATABASE' in app.config:
            initial_db = app.config['DATABASE']
        elif 'DATABASE_URL' in app.config:
            initial_db = app.config['DATABASE_URL']
        else:
            raise ValueError('Missing required configuration data for '
                             'database: DATABASE or DATABASE_URL.')

        self._load_database(app, initial_db)
        self._register_handlers(app)

    def _load_database(self, app, config_value):
        """Build a database from ``config_value`` and install it.

        ``config_value`` may be a ``Database`` instance (used as-is), a
        ``dict`` (see ``_load_from_config_dict``), or anything else, which
        is treated as a connection URL. If ``self.database`` is already a
        ``Proxy`` (installed by ``Model`` before an app was attached), the
        proxy is initialized in place so previously declared models use
        the new database.
        """
        if isinstance(config_value, Database):
            database = config_value
        elif isinstance(config_value, dict):
            # Copy so popping keys does not mutate the caller's config.
            database = self._load_from_config_dict(dict(config_value))
        else:
            # Assume a database connection URL.
            database = db_url_connect(config_value)

        if isinstance(self.database, Proxy):
            self.database.initialize(database)
        else:
            self.database = database

    def _load_from_config_dict(self, config_dict):
        """Instantiate a database class described by ``config_dict``.

        ``name`` and ``engine`` are required and are popped from
        ``config_dict``; the remaining keys become keyword arguments to the
        database class. ``engine`` is either a dotted ``module.ClassName``
        path or a bare class name looked up in the ``peewee`` module.

        :raises RuntimeError: if a required key is missing, the engine
            module cannot be imported, the class is not found, or it is
            not a ``Database`` subclass.
        """
        import importlib

        try:
            name = config_dict.pop('name')
            engine = config_dict.pop('engine')
        except KeyError:
            raise RuntimeError('DATABASE configuration must specify a '
                               '`name` and `engine`.')

        if '.' in engine:
            path, class_name = engine.rsplit('.', 1)
        else:
            path, class_name = 'peewee', engine

        try:
            module = importlib.import_module(path)
        except (ImportError, ValueError, TypeError):
            # ValueError/TypeError cover empty or relative module paths
            # such as engine='.Foo'.
            raise RuntimeError('Unable to import %s' % engine)

        try:
            database_class = getattr(module, class_name)
        except AttributeError:
            raise RuntimeError('Database engine not found %s' % engine)

        # Explicit check rather than `assert`, which is stripped under -O.
        # The isinstance guard stops issubclass() from raising TypeError
        # when the named attribute is not a class at all.
        if not (isinstance(database_class, type) and
                issubclass(database_class, Database)):
            raise RuntimeError('Database engine not a subclass of '
                               'peewee.Database: %s' % engine)

        return database_class(name, **config_dict)

    def _register_handlers(self, app):
        """Open a connection before each request and close it on teardown."""
        app.before_request(self.connect_db)
        app.teardown_request(self.close_db)

    def get_model_class(self):
        """Return a new subclass of ``base_model_class`` bound to the database.

        :raises RuntimeError: if no database (or proxy) has been set yet.
        """
        if self.database is None:
            raise RuntimeError('Database must be initialized.')

        class BaseModel(self.base_model_class):
            class Meta:
                database = self.database

        return BaseModel

    @property
    def Model(self):
        """The shared base model class, created on first access and cached.

        If no app has been attached and no database is set yet, a ``Proxy``
        is installed first so models can be declared before ``init_app``;
        loading the database later initializes that proxy.
        """
        if self._app is None:
            database = getattr(self, 'database', None)
            if database is None:
                self.database = Proxy()

        if not hasattr(self, '_model_class'):
            self._model_class = self.get_model_class()
        return self._model_class

    def connect_db(self):
        """``before_request`` hook: open a database connection."""
        self.database.connect()

    def close_db(self, exc):
        """``teardown_request`` hook: close the connection if it is open.

        ``exc`` is the argument supplied by the teardown hook; it is unused.
        """
        if not self.database.is_closed():
            self.database.close()
186