# ext/horizontal_shard.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Horizontal sharding support.

Defines a rudimentary 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases.

For a usage example, see the :ref:`examples_sharding` example included in
the source distribution.

"""

from .. import inspect
from .. import util
from ..orm.query import Query
from ..orm.session import Session

__all__ = ["ShardedSession", "ShardedQuery"]


class ShardedQuery(Query):
    """A :class:`.Query` that routes execution to one or more shards.

    The ``id_chooser`` and ``query_chooser`` callables are copied from the
    owning session at construction time.  A query may be pinned to a
    single shard with :meth:`.set_shard`.

    """

    def __init__(self, *args, **kwargs):
        super(ShardedQuery, self).__init__(*args, **kwargs)
        self.id_chooser = self.session.id_chooser
        self.query_chooser = self.session.query_chooser
        # None means "not pinned"; the shards are then selected by
        # query_chooser / id_chooser at execution time.
        self._shard_id = None

    def set_shard(self, shard_id):
        """return a new query, limited to a single shard ID.

        all subsequent operations with the returned query will
        be against the single shard regardless of other state.

        :param shard_id: the shard identifier to pin the new query to.
        :return: a clone of this query with ``_shard_id`` set.

        """

        q = self._clone()
        q._shard_id = shard_id
        return q

    def _execute_and_instances(self, context):
        """Execute ``context.statement`` against the applicable shard(s).

        The shard is chosen in this order of precedence:

        1. ``context.identity_token``, if it is not None
        2. the shard set via :meth:`.set_shard`, if any
        3. every shard id returned by ``query_chooser(self)``; the
           results of each shard are concatenated in that order.

        :return: an iterator of the loaded instances / rows.

        """

        def iter_for_shard(shard_id):
            # record the shard on the context so that objects loaded from
            # this result carry the shard id as their identity token
            context.attributes["shard_id"] = context.identity_token = shard_id
            result = self._connection_from_session(
                mapper=self._mapper_zero(), shard_id=shard_id
            ).execute(context.statement, self._params)
            return self.instances(result, context)

        if context.identity_token is not None:
            return iter_for_shard(context.identity_token)
        elif self._shard_id is not None:
            return iter_for_shard(self._shard_id)
        else:
            partial = []
            for shard_id in self.query_chooser(self):
                partial.extend(iter_for_shard(shard_id))

            # if some kind of in memory 'sorting'
            # were done, this is where it would happen
            return iter(partial)

    def _identity_lookup(
        self,
        mapper,
        primary_key_identity,
        identity_token=None,
        lazy_loaded_from=None,
        **kw
    ):
        """override the default Query._identity_lookup method so that we
        search for a given non-token primary key identity across all
        possible identity tokens (e.g. shard ids).

        When ``identity_token`` is given, the lookup is delegated directly
        to the superclass for that token only.  Otherwise ``id_chooser``
        is consulted with a query against ``mapper`` (carrying
        ``lazy_loaded_from`` when present) and each returned shard id is
        tried in order; the first non-None result is returned.

        :return: the object found, or None if no candidate shard
         produced one.

        """

        if identity_token is not None:
            return super(ShardedQuery, self)._identity_lookup(
                mapper,
                primary_key_identity,
                identity_token=identity_token,
                **kw
            )
        else:
            q = self.session.query(mapper)
            if lazy_loaded_from:
                q = q._set_lazyload_from(lazy_loaded_from)
            for shard_id in self.id_chooser(q, primary_key_identity):
                obj = super(ShardedQuery, self)._identity_lookup(
                    mapper, primary_key_identity, identity_token=shard_id, **kw
                )
                if obj is not None:
                    return obj

            return None

    def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
        """Override the default Query._get_impl() method so that we emit
        a query to the DB for each possible identity token, if we don't
        have one already.

        :param primary_key_identity: the primary key value(s) to load.
        :param db_load_fn: callable ``(query, primary_key_identity)`` that
         performs the actual database load.
        :param identity_token: optional identity token; when None and this
         query is pinned to a shard, the pinned shard id is used.

        """

        def _db_load_fn(query, primary_key_identity):
            # load from the database.  The original db_load_fn will
            # use the given Query object to load from the DB, so our
            # shard_id is what will indicate the DB that we query from.
            if self._shard_id is not None:
                return db_load_fn(self, primary_key_identity)

            ident = util.to_list(primary_key_identity)
            # build a ShardedQuery for each shard identifier and
            # try to load from the DB
            for shard_id in self.id_chooser(self, ident):
                q = self.set_shard(shard_id)
                o = db_load_fn(q, ident)
                if o is not None:
                    return o

            # no candidate shard produced the object
            return None

        if identity_token is None and self._shard_id is not None:
            identity_token = self._shard_id

        return super(ShardedQuery, self)._get_impl(
            primary_key_identity, _db_load_fn, identity_token=identity_token
        )


class ShardedSession(Session):
    """A :class:`.Session` that maps shard ids to binds and selects a
    shard for each connection / bind request.

    """

    def __init__(
        self,
        shard_chooser,
        id_chooser,
        query_chooser,
        shards=None,
        query_cls=ShardedQuery,
        **kwargs
    ):
        """Construct a ShardedSession.

        :param shard_chooser: A callable which, passed a Mapper, a mapped
          instance, and possibly a SQL clause, returns a shard ID.  This id
          may be based off of the attributes present within the object, or on
          some round-robin scheme. If the scheme is based on a selection, it
          should set whatever state on the instance to mark it in the future as
          participating in that shard.

        :param id_chooser: A callable, passed a query and a tuple of identity
          values, which should return a list of shard ids where the ID might
          reside.  The databases will be queried in the order of this listing.

        :param query_chooser: For a given Query, returns the list of shard_ids
          where the query should be issued.  Results from all shards returned
          will be combined together into a single listing.

        :param shards: A dictionary of string shard names
          to :class:`~sqlalchemy.engine.Engine` objects.

        """
        super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
        self.shard_chooser = shard_chooser
        self.id_chooser = id_chooser
        self.query_chooser = query_chooser
        # shard id -> bind; populated through bind_shard()
        self.__binds = {}
        self.connection_callable = self.connection
        if shards is not None:
            for k in shards:
                self.bind_shard(k, shards[k])

    def _choose_shard_and_assign(self, mapper, instance, **kw):
        """Return the shard id for ``instance``, choosing one if needed.

        If the instance already has an identity key, the token stored at
        index 2 of that key is returned.  Otherwise, if an identity token
        was already assigned to the instance's state, that token is
        returned.  Failing both (or when ``instance`` is None),
        ``shard_chooser`` is invoked and, for a non-None instance, its
        result is stored on the state as the identity token.

        """
        if instance is not None:
            state = inspect(instance)
            if state.key:
                token = state.key[2]
                assert token is not None
                return token
            # compare against None rather than relying on truthiness, so
            # that an already-assigned falsy shard id (e.g. 0) is honored
            # instead of being re-chosen and overwritten
            elif state.identity_token is not None:
                return state.identity_token

        shard_id = self.shard_chooser(mapper, instance, **kw)
        if instance is not None:
            state.identity_token = shard_id
        return shard_id

    def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
        """Return a connection for the given (or chosen) shard.

        When ``shard_id`` is None it is determined via
        ``_choose_shard_and_assign(mapper, instance)``.  If the session
        has a transaction, the connection is obtained from it; otherwise
        a contextual connection is made from the shard's bind, with
        ``**kwargs`` passed along.

        """
        if shard_id is None:
            shard_id = self._choose_shard_and_assign(mapper, instance)

        if self.transaction is not None:
            return self.transaction.connection(mapper, shard_id=shard_id)
        else:
            return self.get_bind(
                mapper, shard_id=shard_id, instance=instance
            ).contextual_connect(**kwargs)

    def get_bind(
        self, mapper, shard_id=None, instance=None, clause=None, **kw
    ):
        """Return the bind registered for the given (or chosen) shard.

        When ``shard_id`` is None it is determined via
        ``_choose_shard_and_assign``, passing ``clause`` through to the
        ``shard_chooser``.

        :raises KeyError: if no bind was registered for the shard id.

        """
        if shard_id is None:
            shard_id = self._choose_shard_and_assign(
                mapper, instance, clause=clause
            )
        return self.__binds[shard_id]

    def bind_shard(self, shard_id, bind):
        """Register ``bind`` as the bind for ``shard_id``, replacing any
        bind previously registered under that id.

        """
        self.__binds[shard_id] = bind