1# ext/horizontal_shard.py
2# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Horizontal sharding support.
9
Defines a rudimentary 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases.
12
13For a usage example, see the :ref:`examples_sharding` example included in
14the source distribution.
15
16"""
17
18from .. import inspect
19from .. import util
20from ..orm.query import Query
21from ..orm.session import Session
22
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["ShardedSession", "ShardedQuery"]
24
25
class ShardedQuery(Query):
    """Query that executes against one or more shards of a sharded session.

    The ``id_chooser`` and ``query_chooser`` callables are taken from the
    owning session when the query is constructed.  A query may be pinned to
    a single shard with :meth:`.set_shard`.

    """

    def __init__(self, *args, **kwargs):
        super(ShardedQuery, self).__init__(*args, **kwargs)
        # NOTE(review): assumes the session exposes ``id_chooser`` and
        # ``query_chooser`` attributes (as ShardedSession does).
        self.id_chooser = self.session.id_chooser
        self.query_chooser = self.session.query_chooser
        # None means "not pinned to a shard"; see set_shard().
        self._shard_id = None

    def set_shard(self, shard_id):
        """Return a new query, limited to a single shard ID.

        Subsequent operations with the returned query run against that
        shard, except that an identity token already present on the
        execution context takes precedence when executing.

        :param shard_id: the shard id to pin the new query to.
        :return: a copy of this query (made via ``_clone()``) with the
          shard pinned.

        """

        q = self._clone()
        q._shard_id = shard_id
        return q

    def _execute_and_instances(self, context):
        """Execute ``context.statement`` against one or more shards.

        The shard(s) are selected in this order of precedence:

        1. ``context.identity_token``, if it is already set;
        2. the shard pinned via :meth:`.set_shard`;
        3. every shard id returned by ``query_chooser(self)``.  In this case
           the per-shard results are concatenated in the order the shard
           ids were returned.

        """

        def iter_for_shard(shard_id):
            # Record the shard on the context before loading rows.
            # Presumably instances() uses context.identity_token when it
            # keys the loaded objects.
            context.attributes["shard_id"] = context.identity_token = shard_id
            result = self._connection_from_session(
                mapper=self._mapper_zero(), shard_id=shard_id
            ).execute(context.statement, self._params)
            return self.instances(result, context)

        if context.identity_token is not None:
            return iter_for_shard(context.identity_token)
        elif self._shard_id is not None:
            return iter_for_shard(self._shard_id)
        else:
            partial = []
            for shard_id in self.query_chooser(self):
                # extend() drains each shard's results before the next
                # iter_for_shard() call re-assigns context.identity_token.
                partial.extend(iter_for_shard(shard_id))

            # if some kind of in memory 'sorting'
            # were done, this is where it would happen
            return iter(partial)

    def _identity_lookup(
        self,
        mapper,
        primary_key_identity,
        identity_token=None,
        lazy_loaded_from=None,
        **kw
    ):
        """Override the default Query._identity_lookup method so that we
        search for a given non-token primary key identity across all
        possible identity tokens (e.g. shard ids).

        If ``identity_token`` is given, the lookup is delegated directly to
        the base implementation.  Otherwise ``id_chooser`` is consulted with
        a fresh query for ``mapper`` (marked with ``lazy_loaded_from``, if
        given).  Each returned shard id is then tried in turn, and the first
        non-None result from the base lookup is returned.

        ``lazy_loaded_from`` is consumed here and is not forwarded to the
        base implementation.

        :return: the object found, or None.

        """

        if identity_token is not None:
            return super(ShardedQuery, self)._identity_lookup(
                mapper,
                primary_key_identity,
                identity_token=identity_token,
                **kw
            )
        else:
            q = self.session.query(mapper)
            if lazy_loaded_from:
                # presumably lets id_chooser see which object triggered a
                # lazy load
                q = q._set_lazyload_from(lazy_loaded_from)
            for shard_id in self.id_chooser(q, primary_key_identity):
                obj = super(ShardedQuery, self)._identity_lookup(
                    mapper, primary_key_identity, identity_token=shard_id, **kw
                )
                if obj is not None:
                    return obj

            return None

    def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
        """Override the default Query._get_impl() method so that we emit
        a query to the DB for each possible identity token, if we don't
        have one already.

        For a query pinned with :meth:`.set_shard`, that shard is used as
        the identity token and is the only shard loaded from.

        """

        def _db_load_fn(query, primary_key_identity):
            # Load from the database.  db_load_fn uses the Query object it
            # is given to reach the DB, so the shard pinned on that query
            # selects the database.  Honor the ``query`` argument (rather
            # than the enclosing ``self``) so this wrapper keeps the same
            # (query, primary_key_identity) contract as db_load_fn.
            if query._shard_id is not None:
                return db_load_fn(query, primary_key_identity)

            ident = util.to_list(primary_key_identity)
            # Build a shard-pinned query for each candidate shard, in
            # id_chooser order, and return the first successful load.
            for shard_id in query.id_chooser(query, ident):
                obj = db_load_fn(query.set_shard(shard_id), ident)
                if obj is not None:
                    return obj
            return None

        if identity_token is None and self._shard_id is not None:
            identity_token = self._shard_id

        return super(ShardedQuery, self)._get_impl(
            primary_key_identity, _db_load_fn, identity_token=identity_token
        )
130
131
class ShardedSession(Session):
    """A Session whose connections and binds are resolved per shard.

    A shard id is resolved explicitly or through ``shard_chooser``.  It is
    then looked up in the registry populated by :meth:`.bind_shard`.

    """

    def __init__(
        self,
        shard_chooser,
        id_chooser,
        query_chooser,
        shards=None,
        query_cls=ShardedQuery,
        **kwargs
    ):
        """Construct a ShardedSession.

        :param shard_chooser: callable invoked as
          ``shard_chooser(mapper, instance, **kw)`` (``kw`` may include a
          ``clause``).  It returns the shard id for a mapped instance or
          statement.  For an instance without an identity key, the chosen
          id is stored as that instance's identity token, so later calls
          for the same instance reuse it.

        :param id_chooser: callable invoked as ``id_chooser(query, ident)``
          with a primary key identity.  It returns the shard ids that may
          hold that identity, in the order they should be searched.

        :param query_chooser: callable invoked as ``query_chooser(query)``.
          It returns the shard ids a query should run against; the results
          from all of them are combined into a single list.

        :param shards: optional mapping of shard id to
          :class:`~sqlalchemy.engine.Engine`.  Each entry is registered via
          :meth:`.bind_shard`.

        ``query_cls`` and any remaining keyword arguments are passed on to
        :class:`.Session`.

        """
        super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
        self.shard_chooser = shard_chooser
        self.id_chooser = id_chooser
        self.query_chooser = query_chooser
        self.__binds = {}
        # NOTE(review): presumably consulted by the ORM when it needs a
        # per-object connection; routes those requests through connection().
        self.connection_callable = self.connection
        if shards is None:
            return
        for shard_name in shards:
            self.bind_shard(shard_name, shards[shard_name])

    def _choose_shard_and_assign(self, mapper, instance, **kw):
        """Return the shard id for ``instance``, assigning one if needed.

        An instance with an identity key keeps the token stored in that key.
        An instance with a truthy ``identity_token`` keeps that token.
        Otherwise ``shard_chooser`` decides, and for a non-None instance its
        answer is recorded as the instance's identity token.

        """
        if instance is not None:
            instance_state = inspect(instance)
            if instance_state.key:
                # the third element of the identity key is its token
                key_token = instance_state.key[2]
                assert key_token is not None
                return key_token
            if instance_state.identity_token:
                return instance_state.identity_token

        chosen = self.shard_chooser(mapper, instance, **kw)
        if instance is not None:
            instance_state.identity_token = chosen
        return chosen

    def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
        """Return a connection for the resolved shard.

        Within a transaction, this delegates to the transaction's
        ``connection()`` with only ``mapper`` and the shard id.  ``kwargs``
        are used only when there is no transaction, in which case they are
        passed to the bind's ``contextual_connect()``.

        """
        target_shard = (
            shard_id
            if shard_id is not None
            else self._choose_shard_and_assign(mapper, instance)
        )

        if self.transaction is None:
            bind = self.get_bind(
                mapper, shard_id=target_shard, instance=instance
            )
            return bind.contextual_connect(**kwargs)
        return self.transaction.connection(mapper, shard_id=target_shard)

    def get_bind(
        self, mapper, shard_id=None, instance=None, clause=None, **kw
    ):
        """Return the bind registered for the resolved shard id.

        Raises ``KeyError`` if that shard id was never registered with
        :meth:`.bind_shard`.

        """
        if shard_id is not None:
            return self.__binds[shard_id]
        chosen = self._choose_shard_and_assign(
            mapper, instance, clause=clause
        )
        return self.__binds[chosen]

    def bind_shard(self, shard_id, bind):
        """Register ``bind`` for ``shard_id``, replacing any earlier entry."""
        self.__binds[shard_id] = bind
210