1# Copyright 2014-2016 OpenMarket Ltd
2# Copyright 2021 The Matrix.org Foundation C.I.C.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15import collections.abc
16import re
17from typing import (
18    TYPE_CHECKING,
19    Any,
20    Callable,
21    Dict,
22    Iterable,
23    List,
24    Mapping,
25    Optional,
26    Union,
27)
28
29from frozendict import frozendict
30
31from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
32from synapse.api.errors import Codes, SynapseError
33from synapse.api.room_versions import RoomVersion
34from synapse.types import JsonDict
35from synapse.util.async_helpers import yieldable_gather_results
36from synapse.util.frozenutils import unfreeze
37
38from . import EventBase
39
40if TYPE_CHECKING:
41    from synapse.server import HomeServer
42
# Matches every "." that separates two parts of a field path, i.e. every "."
# that is not escaped as "\.". The escape check is a negative lookbehind:
# (?<!X) only matches at a position that is not immediately preceded by X.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
#       the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")

# Inclusive bounds on the integers accepted by `validate_canonicaljson`:
# [-(2^53 - 1), 2^53 - 1].
CANONICALJSON_MAX_INT = 2 ** 53 - 1
CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
52
53
def prune_event(event: EventBase) -> EventBase:
    """Build the redacted form of `event`.

    Redacting strips every key we do not recognise (or that could be dodgy),
    including anything the user supplied, while retaining the information that
    is still needed afterwards, such as the type and state_key.

    Args:
        event: The event to redact.

    Returns:
        A new event built from the pruned dict, carrying over the original's
        stream ordering and outlier flag, and marked as redacted.
    """
    # Imported here rather than at module level, as in the original code.
    from . import make_event_from_dict

    redacted_dict = prune_event_dict(event.room_version, event.get_dict())
    redacted = make_event_from_dict(
        redacted_dict,
        event.room_version,
        event.internal_metadata.get_dict(),
    )

    # Carry the internal fields over from the original event...
    redacted_metadata = redacted.internal_metadata
    redacted_metadata.stream_ordering = event.internal_metadata.stream_ordering
    redacted_metadata.outlier = event.internal_metadata.outlier

    # ... and flag the result as redacted.
    redacted_metadata.redacted = True

    return redacted
81
82
def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
    """Redact an event given as a dict; the dict-level twin of `prune_event`.

    Args:
        room_version: The room version; its flags select which top-level keys
            and which content fields survive redaction.
        event_dict: The event to redact. It is only read from, never written to.

    Returns:
        A new dict holding the surviving top-level keys, the surviving content
        fields, and an `unsigned` section reduced to `age_ts` and
        `replaces_state`. The one exception is a create event in a room version
        using the MSC2176 redaction rules: `event_dict` itself is returned.
    """
    top_level_keys = {
        "event_id",
        "sender",
        "room_id",
        "hashes",
        "signatures",
        "content",
        "type",
        "state_key",
        "depth",
        "prev_events",
        "auth_events",
        "origin",
        "origin_server_ts",
    }

    # Room versions from before MSC2176 keep two further top-level keys.
    if not room_version.msc2176_redaction_rules:
        top_level_keys.update(("prev_state", "membership"))

    event_type = event_dict["type"]

    redacted_content = {}

    def keep(*names: str) -> None:
        # Carry each named content field over, if the event has it.
        content = event_dict["content"]
        for name in names:
            if name in content:
                redacted_content[name] = content[name]

    if event_type == EventTypes.Member:
        keep("membership")
        if room_version.msc3375_redaction_rules:
            keep(EventContentFields.AUTHORISING_USER)
    elif event_type == EventTypes.Create:
        # Under the MSC2176 rules, create events cannot be redacted at all.
        if room_version.msc2176_redaction_rules:
            return event_dict

        keep("creator")
    elif event_type == EventTypes.JoinRules:
        keep("join_rule")
        if room_version.msc3083_join_rules:
            keep("allow")
    elif event_type == EventTypes.PowerLevels:
        keep(
            "users",
            "users_default",
            "events",
            "events_default",
            "state_default",
            "ban",
            "kick",
            "redact",
        )

        if room_version.msc2176_redaction_rules:
            keep("invite")

        if room_version.msc2716_historical:
            keep("historical")

    elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
        keep("aliases")
    elif event_type == EventTypes.RoomHistoryVisibility:
        keep("history_visibility")
    elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
        keep("redacts")
    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
        keep(EventContentFields.MSC2716_NEXT_BATCH_ID)
    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
        keep(EventContentFields.MSC2716_BATCH_ID)
    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
        keep(EventContentFields.MSC2716_MARKER_INSERTION)

    redacted = {
        key: value for key, value in event_dict.items() if key in top_level_keys
    }
    redacted["content"] = redacted_content

    # Only these two unsigned fields survive redaction.
    original_unsigned = event_dict.get("unsigned", {})
    redacted["unsigned"] = {
        name: original_unsigned[name]
        for name in ("age_ts", "replaces_state")
        if name in original_unsigned
    }

    return redacted
180
181
def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
    """Copy the field in 'src' to 'dst'.

    For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
    then dst={"foo":{"bar":5}}.

    Nothing is copied if any part of the path is missing from 'src', or if an
    intermediate value in 'src' is not a mapping. Values are copied by
    reference (shallow).

    Args:
        src: The dict to read from.
        dst: The dict to modify.
        field: List of keys to drill down to in 'src'. The list itself is left
            unmodified.
    """
    if not field:  # this should be impossible
        return

    # Unpack rather than pop(), so that the caller's list is not mutated.
    *parents, key_to_move = field

    if not parents:  # common case e.g. 'origin_server_ts'
        if key_to_move in src:
            dst[key_to_move] = src[key_to_move]
        return

    # Else is a nested field e.g. 'content.body'. Drill down to the dict that
    # holds the key to copy, remembering each intermediate object on the way.
    src_nodes = []
    sub_dict = src
    for sub_field in parents:  # e.g. sub_field => "content"
        if sub_field not in sub_dict:
            return
        child = sub_dict[sub_field]
        # Accept any mapping (dict, frozendict, ...), not just exact types.
        if not isinstance(child, collections.abc.Mapping):
            return
        src_nodes.append(child)
        sub_dict = child

    if key_to_move not in sub_dict:
        return

    # Insert the key into the output dictionary, creating nested objects
    # as required. We couldn't do this any earlier or else we'd need to delete
    # the empty objects if the key didn't exist.
    sub_out_dict = dst
    for sub_field, src_node in zip(parents, src_nodes):
        sub_out_dict = sub_out_dict.setdefault(sub_field, {})
        if sub_out_dict is src_node:
            # An earlier, broader field already copied this whole sub-object
            # by reference, so the key is present already. Writing through it
            # would modify 'src' (and would fail on an immutable mapping).
            return
    sub_out_dict[key_to_move] = sub_dict[key_to_move]
221
222
def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
    """Filter 'dictionary' down to the entries named in 'fields'.

    An entry may contain '.' characters to address a sub-field, so
    ['content.body'] selects the 'body' field of the 'content' object. A
    literal '.' inside a field name is written escaped, as '\\.'.

    Args:
        dictionary: The dictionary to read from.
        fields: The fields to copy over. Values are copied by reference
            (shallow).
    Returns:
        A new dictionary holding only the requested fields. If 'fields' is
        empty, 'dictionary' itself is returned, i.e. everything is included.
    """
    if len(fields) == 0:
        return dictionary

    # Turn each entry into its list of (unescaped) path components:
    # "content.body.thing\.with\.dots" => ["content", "body", "thing.with.dots"]
    paths = [
        [component.replace(r"\.", r".") for component in SPLIT_FIELD_REGEX.split(entry)]
        for entry in fields
    ]

    output: JsonDict = {}
    for path in paths:
        _copy_field(dictionary, output, path)
    return output
257
258
def format_event_raw(d: JsonDict) -> JsonDict:
    """Return the event dict `d` as-is, without applying any formatting."""
    return d
261
262
def format_event_for_client_v1(d: JsonDict) -> JsonDict:
    """Format an event dict in the v1 client style.

    The v2 formatting is applied first. Then the sender, if there is one, is
    duplicated as `user_id`, and a fixed set of keys is copied from `unsigned`
    up to the top level.

    Args:
        d: The event dict; it must have an `unsigned` entry. Modified in place.

    Returns:
        The dict returned by `format_event_for_client_v2`, with the extra keys
        added.
    """
    d = format_event_for_client_v2(d)

    sender = d.get("sender")
    if sender is not None:
        d["user_id"] = sender

    unsigned = d["unsigned"]
    hoisted_keys = (
        "age",
        "redacted_because",
        "replaces_state",
        "prev_content",
        "invite_room_state",
        "knock_room_state",
    )
    for name in hoisted_keys:
        if name in unsigned:
            d[name] = unsigned[name]

    return d
283
284
def format_event_for_client_v2(d: JsonDict) -> JsonDict:
    """Format an event dict in the v2 client style.

    Args:
        d: The event dict. Modified in place.

    Returns:
        The same dict, with the keys listed below removed where present.
    """
    for name in (
        "auth_events",
        "prev_events",
        "hashes",
        "signatures",
        "depth",
        "origin",
        "prev_state",
    ):
        if name in d:
            del d[name]
    return d
298
299
def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
    """Apply the v2 client formatting and additionally drop `room_id`.

    Args:
        d: The event dict. Modified in place.

    Returns:
        The dict returned by `format_event_for_client_v2`, minus `room_id`.
    """
    formatted = format_event_for_client_v2(d)
    formatted.pop("room_id", None)
    return formatted
304
305
def serialize_event(
    e: Union[JsonDict, EventBase],
    time_now_ms: int,
    *,
    as_client_event: bool = True,
    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
    token_id: Optional[str] = None,
    only_event_fields: Optional[List[str]] = None,
    include_stripped_room_state: bool = False,
) -> JsonDict:
    """Serialize event for clients

    Args:
        e: The event to serialize. Anything that is not an `EventBase` (e.g. a
            presence event) is returned unchanged.
        time_now_ms: The current time in milliseconds; used to replace
            `unsigned.age_ts` with `unsigned.age`.
        as_client_event: Whether to run the result through `event_format`.
        event_format: The formatter applied when `as_client_event` is true.
        token_id: If this matches the token ID held in the event's internal
            metadata, the event's transaction ID (if any) is added to
            `unsigned` as `transaction_id`.
        only_event_fields: If non-empty, the result is restricted to these
            fields (see `only_fields`).
        include_stripped_room_state: Some events can have stripped room state
            stored in the `unsigned` field. This is required for invite and knock
            functionality. If this option is False, that state will be removed from the
            event before it is returned. Otherwise, it will be kept.

    Returns:
        The serialized event dictionary.

    Raises:
        TypeError: if `only_event_fields` is non-empty but is not a list of
            strings.
    """

    # FIXME(erikj): To handle the case of presence events and the like
    if not isinstance(e, EventBase):
        return e

    time_now_ms = int(time_now_ms)

    # Should this strip out None's?
    d = dict(e.get_dict())

    d["event_id"] = e.event_id

    unsigned = d["unsigned"]

    # Replace the absolute timestamp with an age relative to now.
    if "age_ts" in unsigned:
        unsigned["age"] = time_now_ms - unsigned["age_ts"]
        del unsigned["age_ts"]

    # The redaction event is itself an event, so serialize it the same way.
    if "redacted_because" in e.unsigned:
        unsigned["redacted_because"] = serialize_event(
            e.unsigned["redacted_because"], time_now_ms, event_format=event_format
        )

    if token_id is not None and token_id == getattr(
        e.internal_metadata, "token_id", None
    ):
        txn_id = getattr(e.internal_metadata, "txn_id", None)
        if txn_id is not None:
            unsigned["transaction_id"] = txn_id

    # invite_room_state and knock_room_state are a list of stripped room state events
    # that are meant to provide metadata about a room to an invitee/knocker. They are
    # intended to only be included in specific circumstances, such as down sync, and
    # should not be included in any other case.
    if not include_stripped_room_state:
        for stripped_state_key in ("invite_room_state", "knock_room_state"):
            unsigned.pop(stripped_state_key, None)

    if as_client_event:
        d = event_format(d)

    if only_event_fields:
        is_list_of_str = isinstance(only_event_fields, list) and all(
            isinstance(f, str) for f in only_event_fields
        )
        if not is_list_of_str:
            raise TypeError("only_event_fields must be a list of strings")
        d = only_fields(d, only_event_fields)

    return d
379
380
class EventClientSerializer:
    """Serializes events that are about to be sent to clients.

    Compared to the module-level `serialize_event`, this can also bundle extra
    information (aggregated relations) into the events it serializes.
    """

    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastore()
        experimental_config = hs.config.experimental
        self._msc1849_enabled = experimental_config.msc1849_enabled
        self._msc3440_enabled = experimental_config.msc3440_enabled

    async def serialize_event(
        self,
        event: Union[JsonDict, EventBase],
        time_now: int,
        *,
        bundle_aggregations: bool = False,
        **kwargs: Any,
    ) -> JsonDict:
        """Serializes a single event.

        Args:
            event: The event being serialized. Anything that is not an
                `EventBase` is returned unchanged.
            time_now: The current time in milliseconds
            bundle_aggregations: Whether to include the bundled aggregations for this
                event. Only applies to non-state events. (State events never include
                bundled aggregations.)
            **kwargs: Arguments to pass to `serialize_event`

        Returns:
            The serialized event
        """
        # To handle the case of presence events and the like
        if not isinstance(event, EventBase):
            return event

        serialized_event = serialize_event(event, time_now, **kwargs)

        # Aggregations are bundled only if all of the following are true:
        #
        # * Support is enabled in the configuration and requested by the caller.
        # * The event is not a state event.
        # * The event has not been redacted.
        should_bundle = (
            self._msc1849_enabled
            and bundle_aggregations
            and not event.is_state()
            and not event.internal_metadata.is_redacted()
        )
        if should_bundle:
            await self._injected_bundled_aggregations(event, time_now, serialized_event)

        return serialized_event

    async def _injected_bundled_aggregations(
        self, event: EventBase, time_now: int, serialized_event: JsonDict
    ) -> None:
        """Potentially injects bundled aggregations into the unsigned portion of the serialized event.

        If an applicable edit is found, the `content` of the serialized event
        is replaced by the edit's new content as well.

        Args:
            event: The event being serialized.
            time_now: The current time in milliseconds
            serialized_event: The serialized event which may be modified.

        """
        # An event which is itself an edit or an annotation gets no bundled
        # aggregations: it does not make sense for it to have related events.
        never_bundled = (RelationTypes.ANNOTATION, RelationTypes.REPLACE)
        relates_to = event.content.get("m.relates_to")
        if (
            isinstance(relates_to, (dict, frozendict))
            and relates_to.get("rel_type") in never_bundled
        ):
            return

        event_id = event.event_id
        room_id = event.room_id

        # The bundled aggregations to include, keyed by relation type.
        bundled = {}

        annotations = await self.store.get_aggregation_groups_for_event(
            event_id, room_id
        )
        if annotations.chunk:
            bundled[RelationTypes.ANNOTATION] = annotations.to_dict()

        references = await self.store.get_relations_for_event(
            event_id, room_id, RelationTypes.REFERENCE, direction="f"
        )
        if references.chunk:
            bundled[RelationTypes.REFERENCE] = references.to_dict()

        # Edits are only looked up for message events.
        if event.type == EventTypes.Message:
            edit = await self.store.get_applicable_edit(event_id, room_id)
        else:
            edit = None

        if edit:
            # Swap in the content from the edit, preserving the relations of
            # the original event.

            # Work on a copy of the edit's content (unfrozen if necessary, so
            # that it can be modified below); otherwise we risk modifying the
            # edit event itself.
            new_content = unfreeze(edit.content.copy()).get("m.new_content", {})
            serialized_event["content"] = new_content

            # Check for existing relations
            original_relation = event.content.get("m.relates_to")
            if original_relation:
                # Keep the relations, ensuring we use a dict copy of the original
                new_content["m.relates_to"] = original_relation.copy()
            else:
                new_content.pop("m.relates_to", None)

            bundled[RelationTypes.REPLACE] = {
                "event_id": edit.event_id,
                "origin_server_ts": edit.origin_server_ts,
                "sender": edit.sender,
            }

        # If this event is the start of a thread, include a summary of the replies.
        if self._msc3440_enabled:
            thread_count, latest_thread_event = await self.store.get_thread_summary(
                event_id, room_id
            )
            if latest_thread_event:
                bundled[RelationTypes.THREAD] = {
                    # Don't bundle aggregations as this could recurse forever.
                    "latest_event": await self.serialize_event(
                        latest_thread_event, time_now, bundle_aggregations=False
                    ),
                    "count": thread_count,
                }

        # If any bundled aggregations were found, merge them into the event.
        if bundled:
            relations = serialized_event["unsigned"].setdefault("m.relations", {})
            relations.update(bundled)

    async def serialize_events(
        self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
    ) -> List[JsonDict]:
        """Serializes multiple events.

        Args:
            events: The events being serialized.
            time_now: The current time in milliseconds
            **kwargs: Arguments to pass to `serialize_event`

        Returns:
            The list of serialized events
        """
        serialized = await yieldable_gather_results(
            self.serialize_event, events, time_now=time_now, **kwargs
        )
        return serialized
541
542
def copy_power_levels_contents(
    old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
) -> Dict[str, Union[int, Dict[str, int]]]:
    """Copy the content of a power_levels event, unfreezing frozendicts along the way

    Args:
        old_power_levels: The content to copy. Each value must be either an int
            or a mapping whose values are all ints.

    Returns:
        A plain dict copy; every nested mapping is copied into a new plain
        dict as well.

    Raises:
        TypeError if the input does not look like a valid power levels event content
    """
    if not isinstance(old_power_levels, collections.abc.Mapping):
        raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))

    copied: Dict[str, Union[int, Dict[str, int]]] = {}
    for key, value in old_power_levels.items():
        if isinstance(value, int):
            copied[key] = value
        elif isinstance(value, collections.abc.Mapping):
            # we should only have one level of nesting, so every inner value
            # has to be an int
            nested: Dict[str, int] = {}
            for inner_key, inner_value in value.items():
                if not isinstance(inner_value, int):
                    raise TypeError(
                        "Invalid power_levels value for %s.%s: %r"
                        % (key, inner_key, inner_value)
                    )
                nested[inner_key] = inner_value
            copied[key] = nested
        else:
            raise TypeError("Invalid power_levels value for %s: %r" % (key, value))

    return copied
576
577
def validate_canonicaljson(value: Any) -> None:
    """
    Check that a JSON value obeys the rules of canonical JSON, recursing into
    dicts, frozendicts, lists and tuples.

    See the appendix section 3.1: Canonical JSON.

    Raises:
        SynapseError (400, BAD_JSON) if the value is or contains any of:
        * An integer outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
        * A float (which includes NaN, Infinity and -Infinity)
        * A value of any other non-JSON type
    """
    if isinstance(value, int):
        if not (CANONICALJSON_MIN_INT <= value <= CANONICALJSON_MAX_INT):
            raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON)
        return

    if isinstance(value, float):
        # Note that Infinity, -Infinity, and NaN are also considered floats.
        raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON)

    if isinstance(value, (dict, frozendict)):
        for member in value.values():
            validate_canonicaljson(member)
        return

    if isinstance(value, (list, tuple)):
        for element in value:
            validate_canonicaljson(element)
        return

    # The remaining JSON values (bool, None, str) are safe; anything else is
    # not something we know how to treat as JSON.
    if value is not None and not isinstance(value, (bool, str)):
        raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
608