1# Copyright 2017 Vector Creations Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Defines the various valid commands.

The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS tuples define which
commands each side of the connection is allowed to send.
"""
19import abc
20import logging
21from typing import Tuple, Type
22
23from synapse.util import json_decoder, json_encoder
24
25logger = logging.getLogger(__name__)
26
27
class Command(metaclass=abc.ABCMeta):
    """Abstract base for every replication command.

    Concrete subclasses declare a class-level ``NAME``, which is the token
    used for the command on the wire. The full wire line for an instance is
    ``NAME + " " + to_line()``.
    """

    NAME: str

    @classmethod
    @abc.abstractmethod
    def from_line(cls, line):
        """Build an instance of this command from its wire representation.

        Args:
            line: the text that followed the command name on the wire (the
                name itself is not included).
        """

    @abc.abstractmethod
    def to_line(self) -> str:
        """Render this command's arguments for the wire, without the leading
        command name.
        """

    def get_logcontext_id(self):
        """Return the string used to name the logcontext while this command
        is being processed.

        Subclasses may override this; the default is just the command name.
        """
        name = self.NAME
        return name
57
58
class _SimpleCommand(Command):
    """A Command whose only argument is a single opaque string, ``data``.

    The argument travels over the wire unchanged in both directions.
    """

    def __init__(self, data):
        # Stored verbatim; no parsing or validation happens here.
        self.data = data

    @classmethod
    def from_line(cls, line):
        instance = cls(line)
        return instance

    def to_line(self) -> str:
        payload = self.data
        return payload
71
72
class ServerCommand(_SimpleCommand):
    """Announces the server's name. The server sends this when a new
    connection is established.

    Format::

        SERVER <server_name>
    """

    NAME = "SERVER"
82
83
class RdataCommand(Command):
    """Sent by the server when a subscribed stream has an update.

    Format::

        RDATA <stream_name> <instance_name> <token> <row_json>

    ``<token>`` is either a numeric stream id or the literal string "batch".
    "batch" lets several rows share a single stream ID: every RDATA in a group
    except the last carries "batch", and the last carries the real stream ID.

    Clients should buffer incoming RDATA with a "batch" token (per
    stream_name) until an RDATA with a numeric stream ID arrives.

    ``<instance_name>`` is the source of the new data (usually "master").

    On the Python side, a "batch" token is represented by ``token`` being None.

    An example of a batched series of RDATA::

        RDATA presence master batch ["@foo:example.com", "online", ...]
        RDATA presence master batch ["@bar:example.com", "online", ...]
        RDATA presence master 59 ["@baz:example.com", "online", ...]
    """

    NAME = "RDATA"

    def __init__(self, stream_name, instance_name, token, row):
        self.stream_name = stream_name
        self.instance_name = instance_name
        self.token = token
        self.row = row

    @classmethod
    def from_line(cls, line):
        # maxsplit=3 keeps any spaces inside the JSON row in the last field.
        stream_name, instance_name, raw_token, row_json = line.split(" ", 3)

        if raw_token == "batch":
            token = None
        else:
            token = int(raw_token)

        row = json_decoder.decode(row_json)
        return cls(stream_name, instance_name, token, row)

    def to_line(self):
        if self.token is None:
            token_field = "batch"
        else:
            token_field = str(self.token)

        fields = (
            self.stream_name,
            self.instance_name,
            token_field,
            json_encoder.encode(self.row),
        )
        return " ".join(fields)

    def get_logcontext_id(self):
        # One logcontext name per stream, rather than one for all RDATA.
        prefix = "RDATA-"
        return prefix + self.stream_name
140
141
class PositionCommand(Command):
    """Sent by an instance to tell others its stream position without having
    to send an RDATA.

    Two tokens are sent: the new position and the last position the instance
    previously sent (in an RDATA or another POSITION). The tokens are chosen
    so that *no* rows were written by the instance between ``prev_token`` and
    ``new_token``. (An instance that has never sent a position may use the
    new position for both.)

    Format::

        POSITION <stream_name> <instance_name> <prev_token> <new_token>

    On receipt, an instance should check whether it has missed any updates
    and, if so, fetch them out of band. It can do this by comparing its own
    view of the sending instance's current token with ``prev_token``.

    ``<instance_name>`` is the process that sent the command and is the
    source of the stream.
    """

    NAME = "POSITION"

    def __init__(self, stream_name, instance_name, prev_token, new_token):
        self.stream_name = stream_name
        self.instance_name = instance_name
        self.prev_token = prev_token
        self.new_token = new_token

    @classmethod
    def from_line(cls, line):
        stream_name, instance_name, raw_prev, raw_new = line.split(" ", 3)
        prev_token = int(raw_prev)
        new_token = int(raw_new)
        return cls(stream_name, instance_name, prev_token, new_token)

    def to_line(self):
        fields = [self.stream_name, self.instance_name]
        fields.extend(str(tok) for tok in (self.prev_token, self.new_token))
        return " ".join(fields)
187
188
class ErrorCommand(_SimpleCommand):
    """Reports an error. Either side may send it; the data is a
    human-readable description of what went wrong.
    """

    NAME = "ERROR"
195
196
class PingCommand(_SimpleCommand):
    """Keep-alive that either side may send.

    The data is arbitrary; it is often a timestamp.
    """

    NAME = "PING"
201
202
class NameCommand(_SimpleCommand):
    """Tells the server who the client is. The client sends it, and the data
    is the client's name.
    """

    NAME = "NAME"
209
210
class ReplicateCommand(Command):
    """Sent by the client to subscribe to streams. It takes no arguments.

    Format::

        REPLICATE
    """

    NAME = "REPLICATE"

    def __init__(self):
        """This command carries no state."""

    @classmethod
    def from_line(cls, line):
        # Whatever follows the command name is ignored.
        return cls()

    def to_line(self):
        # Nothing follows the command name on the wire.
        return ""
230
231
class UserSyncCommand(Command):
    """Sent by the client to inform the server that a user has started or
    stopped syncing on this process.

    This is used by the process handling presence (typically the master) to
    calculate who is online and who is not.

    Includes a timestamp of when the last user sync was.

    Format::

        USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>

    Where <state> is either "start" or "end", corresponding to ``is_syncing``
    being True or False respectively.
    """

    NAME = "USER_SYNC"

    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
        self.instance_id = instance_id
        self.user_id = user_id
        self.is_syncing = is_syncing
        self.last_sync_ms = last_sync_ms

    @classmethod
    def from_line(cls, line):
        """Parse ``<instance_id> <user_id> <state> <last_sync_ms>``.

        Raises:
            ValueError: if the line does not split into four fields, if the
                state is not "start" or "end", or if last_sync_ms is not an
                integer.
        """
        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)

        if state not in ("start", "end"):
            # ValueError matches the errors already raised for malformed input
            # by the unpacking above and by int() below. It subclasses
            # Exception, so callers catching Exception still handle it.
            raise ValueError("Invalid USER_SYNC state %r" % (state,))

        return cls(instance_id, user_id, state == "start", int(last_sync_ms))

    def to_line(self):
        return " ".join(
            (
                self.instance_id,
                self.user_id,
                "start" if self.is_syncing else "end",
                str(self.last_sync_ms),
            )
        )
274
275
class ClearUserSyncsCommand(Command):
    """Asks the server to forget everything the client has told it about
    syncing users.

    This is mainly used when the client is about to shut down.

    Format::

        CLEAR_USER_SYNC <instance_id>
    """

    NAME = "CLEAR_USER_SYNC"

    def __init__(self, instance_id):
        self.instance_id = instance_id

    @classmethod
    def from_line(cls, line):
        # The entire argument string is the instance id.
        instance_id = line
        return cls(instance_id)

    def to_line(self):
        instance_id = self.instance_id
        return instance_id
298
299
class FederationAckCommand(Command):
    """Sent by the client once it has processed the federation stream up to a
    given point, letting the master drop its in-memory caches of that stream.

    Only one worker (the one sending federation) may send this.

    Format::

        FEDERATION_ACK <instance_name> <token>
    """

    NAME = "FEDERATION_ACK"

    def __init__(self, instance_name: str, token: int):
        self.instance_name = instance_name
        self.token = token

    @classmethod
    def from_line(cls, line: str) -> "FederationAckCommand":
        # No maxsplit: exactly two space-separated fields are expected.
        fields = line.split(" ")
        instance_name, raw_token = fields
        return cls(instance_name, int(raw_token))

    def to_line(self) -> str:
        return " ".join((str(self.instance_name), str(self.token)))
325
326
class UserIpCommand(Command):
    """Sent periodically when a worker sees activity from a client.

    Format::

        USER_IP <user_id> <json_args>

    where ``<json_args>`` is a JSON array of the form::

        [<access_token>, <ip>, <user_agent>, <device_id>, <last_seen>]
    """

    NAME = "USER_IP"

    def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
        self.user_id = user_id
        self.access_token = access_token
        self.ip = ip
        self.user_agent = user_agent
        self.device_id = device_id
        self.last_seen = last_seen

    @classmethod
    def from_line(cls, line):
        # maxsplit=1: everything after the user id is one JSON document, which
        # may itself contain spaces.
        user_id, jsn = line.split(" ", 1)

        # The array order must match the tuple encoded in to_line().
        access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)

        return cls(user_id, access_token, ip, user_agent, device_id, last_seen)

    def to_line(self):
        args = (
            self.access_token,
            self.ip,
            self.user_agent,
            self.device_id,
            self.last_seen,
        )
        return " ".join((self.user_id, json_encoder.encode(args)))
367
368
class RemoteServerUpCommand(_SimpleCommand):
    """Announces that a remote server, previously considered "down", is
    reachable again, so its retry timings should be reset.

    When a client sends this, the server relays it to all other workers.

    Format::

        REMOTE_SERVER_UP <server>
    """

    NAME = "REMOTE_SERVER_UP"
381
382
# Every command class known to this module. COMMAND_MAP below is built from
# this tuple, and parse_command_from_line looks names up in COMMAND_MAP, so a
# command must be listed here to be parseable.
_COMMANDS: Tuple[Type[Command], ...] = (
    ServerCommand,
    RdataCommand,
    PositionCommand,
    ErrorCommand,
    PingCommand,
    NameCommand,
    ReplicateCommand,
    UserSyncCommand,
    FederationAckCommand,
    UserIpCommand,
    RemoteServerUpCommand,
    ClearUserSyncsCommand,
)

# Map of command name to command type.
# NOTE(review): if two classes ever shared a NAME, the one listed later in
# _COMMANDS would silently win here.
COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}

# The commands the server is allowed to send
VALID_SERVER_COMMANDS: Tuple[str, ...] = (
    ServerCommand.NAME,
    RdataCommand.NAME,
    PositionCommand.NAME,
    ErrorCommand.NAME,
    PingCommand.NAME,
    RemoteServerUpCommand.NAME,
)

# The commands the client is allowed to send
VALID_CLIENT_COMMANDS: Tuple[str, ...] = (
    NameCommand.NAME,
    ReplicateCommand.NAME,
    PingCommand.NAME,
    UserSyncCommand.NAME,
    ClearUserSyncsCommand.NAME,
    FederationAckCommand.NAME,
    UserIpCommand.NAME,
    ErrorCommand.NAME,
    RemoteServerUpCommand.NAME,
)
423
424
def parse_command_from_line(line: str) -> Command:
    """Turn a received line into the matching Command instance.

    The caller must already have stripped surrounding whitespace from
    ``line`` and rejected blank lines.

    The text before the first space is the command name. Everything after
    that space (or the empty string, if there is no space) is handed to the
    command class's ``from_line``.

    Raises:
        KeyError: if the command name is not a key of COMMAND_MAP.
    """
    # partition() yields (line, "", "") when there is no space, which gives
    # the "no arguments" case for free.
    cmd_name, _sep, rest_of_line = line.partition(" ")
    return COMMAND_MAP[cmd_name].from_line(rest_of_line)
441