import math
import sys

from flask import abort
from flask import render_template
from flask import request
from peewee import Database
from peewee import DoesNotExist
from peewee import Model
from peewee import Proxy
from peewee import SelectQuery
from playhouse.db_url import connect as db_url_connect


class PaginatedQuery(object):
    """Paginate a select query, taking the page number from the request.

    :param query_or_model: a ``SelectQuery`` instance, or a model class
        (in which case ``model.select()`` is paginated).
    :param paginate_by: number of objects per page.
    :param page_var: name of the query-string argument holding the page.
    :param page: explicit page number; when truthy it overrides the
        request argument.
    :param check_bounds: when true, ``get_object_list`` aborts with a 404
        if the requested page is beyond the last page.
    """

    def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
                 check_bounds=False):
        self.paginate_by = paginate_by
        self.page_var = page_var
        # Any falsy page (None, 0, '') means "read it from the request".
        self.page = page or None
        self.check_bounds = check_bounds

        if isinstance(query_or_model, SelectQuery):
            self.query = query_or_model
            self.model = self.query.model
        else:
            self.model = query_or_model
            self.query = self.model.select()

    def get_page(self):
        """Return the current page number (always >= 1 when read from the
        request; an explicitly supplied ``page`` is returned as-is)."""
        if self.page is not None:
            return self.page

        curr_page = request.args.get(self.page_var)
        if curr_page and curr_page.isdigit():
            # str.isdigit() accepts characters (e.g. superscript digits)
            # that int() rejects, so a crafted query string must not be
            # allowed to raise here; fall back to the first page instead.
            try:
                return max(1, int(curr_page))
            except ValueError:
                pass
        return 1

    def get_page_count(self):
        """Return the total number of pages, caching the result.

        The count query is executed at most once per instance.
        """
        if not hasattr(self, '_page_count'):
            self._page_count = int(math.ceil(
                float(self.query.count()) / self.paginate_by))
        return self._page_count

    def get_object_list(self):
        """Return the query restricted to the current page.

        When ``check_bounds`` is set, aborts with a 404 if the current page
        is greater than the page count.
        """
        current_page = self.get_page()
        # NOTE: for an empty result set the page count is 0, so with
        # check_bounds enabled even page 1 aborts with a 404.
        if self.check_bounds and current_page > self.get_page_count():
            abort(404)
        return self.query.paginate(current_page, self.paginate_by)


def get_object_or_404(query_or_model, *query):
    """Return the single object matching ``query`` or abort with a 404.

    :param query_or_model: a ``SelectQuery`` or a model class.
    :param query: expressions passed to ``.where()``.
    """
    if not isinstance(query_or_model, SelectQuery):
        query_or_model = query_or_model.select()
    try:
        return query_or_model.where(*query).get()
    except DoesNotExist:
        abort(404)


def object_list(template_name, query, context_variable='object_list',
                paginate_by=20, page_var='page', page=None, check_bounds=True,
                **kwargs):
    """Render ``template_name`` with a paginated list of objects.

    The template context receives the page of objects under
    ``context_variable``, the ``PaginatedQuery`` as ``pagination``, the
    current page number as ``page``, plus any extra ``kwargs``.
    """
    paginated_query = PaginatedQuery(
        query,
        paginate_by=paginate_by,
        page_var=page_var,
        page=page,
        check_bounds=check_bounds)
    kwargs[context_variable] = paginated_query.get_object_list()
    return render_template(
        template_name,
        pagination=paginated_query,
        page=paginated_query.get_page(),
        **kwargs)


def get_current_url():
    """Return the path of the current request, including the query string
    (if any), as text."""
    query_string = request.query_string
    if not query_string:
        return request.path
    # The raw query string is a byte string; decode it so that the result
    # is not rendered as "/path?b'...'" on Python 3.
    if isinstance(query_string, bytes):
        query_string = query_string.decode('utf-8', 'replace')
    return '%s?%s' % (request.path, query_string)


def get_next_url(default='/'):
    """Return the ``next`` URL from the query string or form data.

    Falls back to ``default`` when neither carries a non-empty ``next``.

    NOTE: the value is client-supplied and is returned without validation;
    callers that redirect to it should verify it is a safe target.
    """
    if request.args.get('next'):
        return request.args['next']
    elif request.form.get('next'):
        return request.form['next']
    return default


class FlaskDB(object):
    """Bind a Peewee database to a Flask application.

    :param app: Flask application; when given, ``init_app`` is called
        immediately.
    :param database: a config dict, a connection URL, a ``Database``
        instance, or ``None`` to read ``DATABASE`` / ``DATABASE_URL`` from
        the application config.
    :param model_class: base class used for the generated ``Model``.
    """

    def __init__(self, app=None, database=None, model_class=Model):
        self.database = None  # Reference to actual Peewee database instance.
        self.base_model_class = model_class
        self._app = app
        self._db = database  # dict, url, Database, or None (default).
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Load the database for ``app`` and register request handlers.

        :raises ValueError: if no database was passed to the constructor
            and the app config has neither ``DATABASE`` nor
            ``DATABASE_URL``.
        """
        self._app = app

        if self._db is None:
            if 'DATABASE' in app.config:
                initial_db = app.config['DATABASE']
            elif 'DATABASE_URL' in app.config:
                initial_db = app.config['DATABASE_URL']
            else:
                raise ValueError('Missing required configuration data for '
                                 'database: DATABASE or DATABASE_URL.')
        else:
            initial_db = self._db

        self._load_database(app, initial_db)
        self._register_handlers(app)

    def _load_database(self, app, config_value):
        """Resolve ``config_value`` to a database and store it.

        If ``self.database`` is already a ``Proxy`` (created by accessing
        ``Model`` before initialization), the proxy is initialized with the
        real database instead of being replaced.
        """
        if isinstance(config_value, Database):
            database = config_value
        elif isinstance(config_value, dict):
            # Pass a copy: the loader pops keys from the dict it is given.
            database = self._load_from_config_dict(dict(config_value))
        else:
            # Assume a database connection URL.
            database = db_url_connect(config_value)

        if isinstance(self.database, Proxy):
            self.database.initialize(database)
        else:
            self.database = database

    def _load_from_config_dict(self, config_dict):
        """Instantiate a database from a configuration dictionary.

        ``config_dict`` must contain ``name`` and ``engine``; both are
        popped and the remaining items are passed as keyword arguments to
        the database class. ``engine`` is either a dotted import path or a
        bare class name looked up in the ``peewee`` module.

        :raises RuntimeError: if a required key is missing, the engine
            cannot be imported or found, or it is not a ``Database``
            subclass.
        """
        try:
            name = config_dict.pop('name')
            engine = config_dict.pop('engine')
        except KeyError:
            raise RuntimeError('DATABASE configuration must specify a '
                               '`name` and `engine`.')

        if '.' in engine:
            path, class_name = engine.rsplit('.', 1)
        else:
            path, class_name = 'peewee', engine

        try:
            __import__(path)
            module = sys.modules[path]
            database_class = getattr(module, class_name)
        except ImportError:
            raise RuntimeError('Unable to import %s' % engine)
        except AttributeError:
            raise RuntimeError('Database engine not found %s' % engine)

        # Explicit check rather than ``assert``: assertions are stripped
        # when Python runs with -O, and issubclass() raises TypeError when
        # the resolved attribute is not a class at all.
        if not (isinstance(database_class, type) and
                issubclass(database_class, Database)):
            raise RuntimeError('Database engine not a subclass of '
                               'peewee.Database: %s' % engine)

        return database_class(name, **config_dict)

    def _register_handlers(self, app):
        """Connect before each request and close on request teardown."""
        app.before_request(self.connect_db)
        app.teardown_request(self.close_db)

    def get_model_class(self):
        """Build a base model class bound to ``self.database``.

        :raises RuntimeError: if no database has been set.
        """
        if self.database is None:
            raise RuntimeError('Database must be initialized.')

        class BaseModel(self.base_model_class):
            class Meta:
                database = self.database

        return BaseModel

    @property
    def Model(self):
        """Base model class bound to this database (created once, cached).

        If no app has been attached yet and no database is set, a ``Proxy``
        is installed as a placeholder so that models can be declared before
        ``init_app`` is called.
        """
        if self._app is None:
            database = getattr(self, 'database', None)
            if database is None:
                self.database = Proxy()

        if not hasattr(self, '_model_class'):
            self._model_class = self.get_model_class()
        return self._model_class

    def connect_db(self):
        """Open the database connection (``before_request`` handler)."""
        self.database.connect()

    def close_db(self, exc):
        """Close the connection if it is open (``teardown_request`` handler).

        :param exc: exception passed by Flask on teardown; unused.
        """
        if not self.database.is_closed():
            self.database.close()