1# Copyright 2019 New Vector Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15from typing import TYPE_CHECKING, List
16
17from prometheus_client import Gauge
18
19from synapse.api.errors import HttpResponseException
20from synapse.events import EventBase
21from synapse.federation.persistence import TransactionActions
22from synapse.federation.units import Edu, Transaction
23from synapse.logging.opentracing import (
24    extract_text_map,
25    set_tag,
26    start_active_span_follows_from,
27    tags,
28    whitelisted_homeserver,
29)
30from synapse.types import JsonDict
31from synapse.util import json_decoder
32from synapse.util.metrics import measure_func
33
34if TYPE_CHECKING:
35    import synapse.server
36
logger = logging.getLogger(__name__)

# Per-destination gauge, labelled by `server_name`. TransactionManager sets it
# to `origin_server_ts / 1000` of the last PDU in each successfully sent
# transaction, and only for destinations listed in the configured
# `federation_metrics_domains`.
last_pdu_ts_metric = Gauge(
    name="synapse_federation_last_sent_pdu_time",
    documentation="The timestamp of the last PDU which was successfully sent to the given domain",
    labelnames=["server_name"],
)
44
45
class TransactionManager:
    """Helper class which handles building and sending transactions

    shared between PerDestinationQueue objects
    """

    def __init__(self, hs: "synapse.server.HomeServer"):
        self._server_name = hs.hostname
        self.clock = hs.get_clock()  # nb must be called this for @measure_func
        self._store = hs.get_datastore()
        self._transaction_actions = TransactionActions(self._store)
        self._transport_layer = hs.get_federation_transport_client()

        # Destinations for which we export the last-sent-PDU timestamp gauge.
        self._federation_metrics_domains = (
            hs.config.federation.federation_metrics_domains
        )

        # HACK to get unique tx id: the counter is seeded from the current
        # time in ms (presumably so IDs don't repeat after a restart) and is
        # incremented once per transaction built.
        self._next_txn_id = int(self.clock.time_msec())

    @measure_func("_send_new_transaction")
    async def send_new_transaction(
        self,
        destination: str,
        pdus: List[EventBase],
        edus: List[Edu],
    ) -> None:
        """Build a single transaction from the given PDUs and EDUs and send it.

        Args:
            destination: The destination to send to (e.g. 'example.org')
            pdus: In-order list of PDUs to send
            edus: List of EDUs to send

        Raises:
            HttpResponseException: propagated from the transport layer after
                being logged and the active span being tagged as an error.
        """

        # Make a transaction-sending opentracing span. This span follows on from
        # all the edus in that transaction. This needs to be done since there is
        # no active span here, so if the edus were not received by the remote the
        # span would have no causality and it would be forgotten.
        span_contexts = []
        keep_destination = whitelisted_homeserver(destination)

        for edu in edus:
            context = edu.get_context()
            if context:
                span_contexts.append(extract_text_map(json_decoder.decode(context)))
            # Only forward the tracing context to destinations on the
            # opentracing whitelist. For every other server, strip it so that
            # our span context is not leaked. (This condition was previously
            # inverted: whitelisted servers lost the context and all other
            # servers received it.)
            if not keep_destination:
                edu.strip_context()

        with start_active_span_follows_from("send_transaction", span_contexts):
            logger.debug("TX [%s] _attempt_new_transaction", destination)

            # Claim the id up front so it is never handed out twice.
            txn_id = str(self._next_txn_id)
            self._next_txn_id += 1

            logger.debug(
                "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)",
                destination,
                txn_id,
                len(pdus),
                len(edus),
            )

            transaction = Transaction(
                origin_server_ts=int(self.clock.time_msec()),
                transaction_id=txn_id,
                origin=self._server_name,
                destination=destination,
                pdus=[p.get_pdu_json() for p in pdus],
                edus=[edu.get_dict() for edu in edus],
            )

            logger.info(
                "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)",
                destination,
                txn_id,
                transaction.transaction_id,
                len(pdus),
                len(edus),
            )

            # Actually send the transaction

            # FIXME (erikj): This is a bit of a hack to make the Pdu age
            # keys work
            # FIXME (richardv): I also believe it no longer works. We (now?) store
            #  "age_ts" in "unsigned" rather than at the top level. See
            #  https://github.com/matrix-org/synapse/issues/8429.
            def json_data_cb() -> JsonDict:
                """Serialise the transaction, converting any top-level
                `age_ts` on a PDU into `unsigned.age` relative to now."""
                data = transaction.get_dict()
                now = int(self.clock.time_msec())
                for p in data.get("pdus", ()):
                    if "age_ts" in p:
                        unsigned = p.setdefault("unsigned", {})
                        unsigned["age"] = now - int(p["age_ts"])
                        del p["age_ts"]
                return data

            try:
                response = await self._transport_layer.send_transaction(
                    transaction, json_data_cb
                )
            except HttpResponseException as e:
                set_tag(tags.ERROR, True)
                logger.info(
                    "TX [%s] {%s} got %d response", destination, txn_id, e.code
                )
                raise

            logger.info("TX [%s] {%s} got 200 response", destination, txn_id)

            # The remote has accepted the transaction. Its per-PDU results are
            # untrusted input, so a malformed map must not turn this successful
            # send into an exception.
            pdu_results = response.get("pdus", {})
            if isinstance(pdu_results, dict):
                for e_id, r in pdu_results.items():
                    if isinstance(r, (dict, list, str)) and "error" in r:
                        logger.warning(
                            "TX [%s] {%s} Remote returned error for %s: %s",
                            destination,
                            txn_id,
                            e_id,
                            r,
                        )

            if pdus and destination in self._federation_metrics_domains:
                last_pdu = pdus[-1]
                last_pdu_ts_metric.labels(server_name=destination).set(
                    last_pdu.origin_server_ts / 1000
                )
175