1import itertools
2import logging
3
4from bidict import bidict
5
# Fallback logger, returned by BaseManager._get_logger() when neither the
# manager itself nor its attached server provides a logger.
default_logger = logging.getLogger('socketio')
7
8
class BaseManager(object):
    """Manage client connections.

    This class keeps track of all the clients and the rooms they are in, to
    support the broadcasting of messages. The data used by this class is
    stored in a memory structure, making it appropriate only for single process
    services. More sophisticated storage backends can be implemented by
    subclasses.

    Internal data layout:

    - ``rooms[namespace][room]`` is a ``bidict`` mapping ``sid`` to
      ``eio_sid``. The room named ``None`` holds every client connected to
      the namespace, and each client is also placed in a personal room named
      after its own ``sid`` (see :meth:`connect`).
    - ``callbacks[sid]`` maps ACK ids to application callbacks. Key ``0`` is
      reserved: it stores the ``itertools.count`` used to generate the ids.
    - ``pending_disconnect[namespace]`` is the list of sids for which
      :meth:`pre_disconnect` was called but :meth:`disconnect` was not yet.
    """
    def __init__(self):
        self.logger = None
        self.server = None
        self.rooms = {}  # self.rooms[namespace][room][sio_sid] = eio_sid
        # NOTE(review): not referenced anywhere in this class; presumably
        # kept for subclasses or backwards compatibility -- confirm.
        self.eio_to_sid = {}
        self.callbacks = {}  # self.callbacks[sid][ack_id] = callback
        self.pending_disconnect = {}  # self.pending_disconnect[ns] = [sid]

    def set_server(self, server):
        """Attach the server object used to generate ids and send packets."""
        self.server = server

    def initialize(self):
        """Invoked before the first request is received. Subclasses can add
        their initialization code here.
        """
        pass

    def get_namespaces(self):
        """Return an iterable with the active namespace names.

        The returned object is a live view of the keys of ``self.rooms``.
        """
        return self.rooms.keys()

    def get_participants(self, namespace, room):
        """Return an iterable with the active participants in a room.

        This is a generator of ``(sid, eio_sid)`` tuples. Because it is a
        generator, a ``KeyError`` for an unknown namespace or room is only
        raised once iteration starts.
        """
        # ``_fwdm`` is the plain forward dict inside the bidict. Iterating
        # over a copy lets callers (e.g. ``close_room``) remove participants
        # while the iteration is in progress.
        for sid, eio_sid in self.rooms[namespace][room]._fwdm.copy().items():
            yield sid, eio_sid

    def connect(self, eio_sid, namespace):
        """Register a client connection to a namespace.

        :param eio_sid: the Engine.IO session id of the client.
        :param namespace: the namespace the client is connecting to.
        :return: the newly generated Socket.IO ``sid`` for the connection.
        """
        sid = self.server.eio.generate_id()
        # room ``None`` tracks every client connected to the namespace
        self.enter_room(sid, namespace, None, eio_sid=eio_sid)
        # personal room, used to address this client individually
        self.enter_room(sid, namespace, sid, eio_sid=eio_sid)
        return sid

    def is_connected(self, sid, namespace):
        """Return ``True`` if the client is connected to the namespace.

        Clients that are in the process of being disconnected (see
        :meth:`pre_disconnect`) are reported as not connected.
        """
        if sid in self.pending_disconnect.get(namespace, ()):
            # the client is in the process of being disconnected
            return False
        try:
            return self.rooms[namespace][None][sid] is not None
        except KeyError:
            # unknown namespace or sid
            return False

    def sid_from_eio_sid(self, eio_sid, namespace):
        """Return the ``sid`` associated with an Engine.IO session id.

        Returns ``None`` if the namespace or the ``eio_sid`` are unknown.
        """
        try:
            # ``_invm`` is the plain inverse (eio_sid -> sid) dict inside
            # the bidict.
            return self.rooms[namespace][None]._invm[eio_sid]
        except KeyError:
            return None

    def eio_sid_from_sid(self, sid, namespace):
        """Return the Engine.IO session id associated with a ``sid``.

        Returns ``None`` if the namespace or the ``sid`` are unknown.
        """
        if namespace not in self.rooms:
            return None
        return self.rooms[namespace][None].get(sid)

    def can_disconnect(self, sid, namespace):
        """Return ``True`` if the client can be disconnected, which in this
        implementation means that it is currently connected.
        """
        return self.is_connected(sid, namespace)

    def pre_disconnect(self, sid, namespace):
        """Put the client in the to-be-disconnected list.

        This allows the client data structures to be present while the
        disconnect handler is invoked, but still recognize the fact that the
        client is soon going away.

        :return: the Engine.IO session id of the client, or ``None`` if the
                 ``sid`` is not known in the namespace.
        :raises KeyError: if the namespace has no connected clients.
        """
        # Look up first, so that a failure does not leave a stale entry in
        # the pending list.
        eio_sid = self.rooms[namespace][None].get(sid)
        self.pending_disconnect.setdefault(namespace, []).append(sid)
        return eio_sid

    def disconnect(self, sid, namespace):
        """Register a client disconnect from a namespace.

        The client is removed from all the rooms it is in, its pending ACK
        callbacks are discarded and it is taken off the to-be-disconnected
        list. Nothing is done if the namespace has no connected clients.
        """
        if namespace not in self.rooms:
            return
        # The list is fully built before any room is modified, so the
        # namespace dictionary is never mutated while it is being iterated.
        joined = [room_name
                  for room_name, members in self.rooms[namespace].items()
                  if sid in members]
        for room_name in joined:
            self.leave_room(sid, namespace, room_name)
        self.callbacks.pop(sid, None)
        pending = self.pending_disconnect.get(namespace)
        if pending is not None and sid in pending:
            pending.remove(sid)
            if not pending:
                del self.pending_disconnect[namespace]

    def enter_room(self, sid, namespace, room, eio_sid=None):
        """Add a client to a room.

        :param sid: the Socket.IO session id of the client.
        :param namespace: the namespace the room belongs to.
        :param room: the room name. The namespace and the room are created
                     if they do not exist yet.
        :param eio_sid: the Engine.IO session id of the client. If omitted
                        it is looked up from the client's registration in
                        the namespace.
        :raises KeyError: if ``eio_sid`` is omitted and the client is not
                          connected to the namespace.
        """
        if eio_sid is None:
            # Resolve before creating any container, so that a failed
            # lookup does not leave empty namespaces or rooms behind.
            eio_sid = self.rooms[namespace][None][sid]
        namespace_rooms = self.rooms.setdefault(namespace, {})
        if room not in namespace_rooms:
            namespace_rooms[room] = bidict()
        namespace_rooms[room][sid] = eio_sid

    def leave_room(self, sid, namespace, room):
        """Remove a client from a room.

        Rooms and namespaces that become empty are deleted. Unknown
        namespaces, rooms or clients are silently ignored.
        """
        try:
            namespace_rooms = self.rooms[namespace]
            del namespace_rooms[room][sid]
        except KeyError:
            return
        if not namespace_rooms[room]:
            del namespace_rooms[room]
            if not namespace_rooms:
                del self.rooms[namespace]

    def close_room(self, room, namespace):
        """Remove all participants from a room.

        Unknown namespaces or rooms are silently ignored.
        """
        try:
            # get_participants() iterates over a snapshot, so it is safe to
            # remove clients from the room inside the loop.
            for sid, _ in self.get_participants(namespace, room):
                self.leave_room(sid, namespace, room)
        except KeyError:
            pass

    def get_rooms(self, sid, namespace):
        """Return the rooms a client is in.

        The result is a list of room names. The internal ``None`` room is
        never included. An empty list is returned for unknown namespaces.
        """
        try:
            namespace_rooms = self.rooms[namespace]
        except KeyError:
            return []
        return [room_name for room_name, members in namespace_rooms.items()
                if room_name is not None and sid in members]

    def emit(self, event, data, namespace, room=None, skip_sid=None,
             callback=None, **kwargs):
        """Emit a message to a single client, a room, or all the clients
        connected to the namespace.

        :param event: the event name.
        :param data: the event payload.
        :param namespace: the namespace to emit on.
        :param room: the recipient room. The default of ``None`` addresses
                     all the clients connected to the namespace.
        :param skip_sid: a ``sid`` or a list of ``sid`` values that should
                         not receive the message.
        :param callback: if given, an ACK id is generated for each recipient
                         and the callback is registered under it.
        :param kwargs: ignored by this implementation.
        """
        if namespace not in self.rooms or room not in self.rooms[namespace]:
            return
        if not isinstance(skip_sid, list):
            skip_sid = [skip_sid]
        for sid, eio_sid in self.get_participants(namespace, room):
            if sid in skip_sid:
                continue
            ack_id = None
            if callback is not None:
                ack_id = self._generate_ack_id(sid, callback)
            self.server._emit_internal(eio_sid, event, data, namespace,
                                       ack_id)

    def trigger_callback(self, sid, id, data):
        """Invoke an application callback.

        :param sid: the session id of the client that sent the ACK.
        :param id: the ACK id received from the client.
        :param data: a sequence with the arguments to pass to the callback.

        Unknown callbacks are logged and ignored. A callback can only be
        triggered once.
        """
        callback = None
        # Key 0 holds the ACK id counter for the client (see
        # ``_generate_ack_id``). It is not a callback, so an ACK carrying
        # that id must neither invoke nor remove it.
        if id != 0:
            callback = self.callbacks.get(sid, {}).pop(id, None)
        if callback is None:
            # if we get an unknown callback we just ignore it
            self._get_logger().warning('Unknown callback received, ignoring.')
            return
        callback(*data)

    def _generate_ack_id(self, sid, callback):
        """Generate a unique identifier for an ACK packet.

        The ids are per-client integers starting at 1. The callback is
        stored under the returned id until it is triggered or the client
        disconnects.
        """
        if sid not in self.callbacks:
            # key 0 is reserved for the id generator of this client
            self.callbacks[sid] = {0: itertools.count(1)}
        ack_id = next(self.callbacks[sid][0])
        self.callbacks[sid][ack_id] = callback
        return ack_id

    def _get_logger(self):
        """Get the appropriate logger

        Prevents uninitialized servers in write-only mode from failing.

        The manager's own logger is preferred, then the logger of the
        attached server, and finally the module level default logger.
        """

        if self.logger:
            return self.logger
        elif self.server:
            return self.server.logger
        else:
            return default_logger
193