1# Copyright 2020 Matrix.org Foundation C.I.C.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import Any, Dict, List, Optional, Tuple, Union, cast
15
16import attr
17
18from synapse.api.constants import LoginType
19from synapse.api.errors import StoreError
20from synapse.storage._base import SQLBaseStore, db_to_json
21from synapse.storage.database import LoggingTransaction
22from synapse.types import JsonDict
23from synapse.util import json_encoder, stringutils
24
25
@attr.s(slots=True)
class UIAuthSessionData:
    """The stored data describing a single user-interactive auth session."""

    session_id = attr.ib(type=str)
    # The dictionary from the client root level, not the 'auth' key.
    clientdict = attr.ib(type=JsonDict)
    # The URI and method the session was initiated with. These are checked at
    # each stage of the authentication to ensure that the requested operation
    # has not changed.
    uri = attr.ib(type=str)
    method = attr.ib(type=str)
    # A string description of the operation that the current authentication is
    # authorising.
    description = attr.ib(type=str)
39
40
class UIAuthWorkerStore(SQLBaseStore):
    """
    Manage user interactive authentication sessions.
    """

    async def create_ui_auth_session(
        self,
        clientdict: JsonDict,
        uri: str,
        method: str,
        description: str,
    ) -> UIAuthSessionData:
        """
        Creates a new user interactive authentication session.

        The session can be used to track the stages necessary to authenticate a
        user across multiple HTTP requests.

        Args:
            clientdict:
                The dictionary from the client root level, not the 'auth' key.
            uri:
                The URI this session was initiated with, this is checked at each
                stage of the authentication to ensure that the asked for
                operation has not changed.
            method:
                The method this session was initiated with, this is checked at each
                stage of the authentication to ensure that the asked for
                operation has not changed.
            description:
                A string description of the operation that the current
                authentication is authorising.
        Returns:
            The newly created session.
        Raises:
            StoreError if a unique session ID cannot be generated.
        """
        # The clientdict gets stored as JSON.
        clientdict_json = json_encoder.encode(clientdict)

        # Autogenerate a session ID and try to create it. We may clash with an
        # existing session, so try a few times until one goes through, giving
        # up eventually.
        for _ in range(5):
            session_id = stringutils.random_string(24)

            try:
                await self.db_pool.simple_insert(
                    table="ui_auth_sessions",
                    values={
                        "session_id": session_id,
                        "clientdict": clientdict_json,
                        "uri": uri,
                        "method": method,
                        "description": description,
                        "serverdict": "{}",
                        "creation_time": self.hs.get_clock().time_msec(),
                    },
                    desc="create_ui_auth_session",
                )
            except self.db_pool.engine.module.IntegrityError:
                # The generated ID (presumably) collided with an existing
                # session; try again with a fresh one.
                continue

            return UIAuthSessionData(
                session_id, clientdict, uri, method, description
            )

        raise StoreError(500, "Couldn't generate a session ID.")

    async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
        """Retrieve a UI auth session.

        Args:
            session_id: The ID of the session.
        Returns:
            The session's data: its ID, client dict, URI, method and
            description.
        Raises:
            StoreError if the session is not found.
        """
        result = await self.db_pool.simple_select_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            retcols=("clientdict", "uri", "method", "description"),
            desc="get_ui_auth_session",
        )

        # The clientdict is stored as JSON; decode it back into a dict.
        result["clientdict"] = db_to_json(result["clientdict"])

        return UIAuthSessionData(session_id, **result)

    async def mark_ui_auth_stage_complete(
        self,
        session_id: str,
        stage_type: str,
        result: Union[str, bool, JsonDict],
    ) -> None:
        """
        Mark a session stage as completed.

        Args:
            session_id: The ID of the corresponding session.
            stage_type: The completed stage type.
            result: The result of the stage verification. Stored as JSON.
        Raises:
            StoreError if the session cannot be found.
        """
        # Add (or update) the results of the current stage to the database.
        #
        # Note that we need to allow for the same stage to complete multiple
        # times here so that registration is idempotent.
        try:
            await self.db_pool.simple_upsert(
                table="ui_auth_sessions_credentials",
                keyvalues={"session_id": session_id, "stage_type": stage_type},
                values={"result": json_encoder.encode(result)},
                desc="mark_ui_auth_stage_complete",
            )
        except self.db_pool.engine.module.IntegrityError:
            # NOTE(review): presumably raised by a foreign-key constraint on
            # session_id when the session does not exist.
            raise StoreError(400, "Unknown session ID: %s" % (session_id,))

    async def get_completed_ui_auth_stages(
        self, session_id: str
    ) -> Dict[str, Union[str, bool, JsonDict]]:
        """
        Retrieve the completed stages of a UI authentication session.

        Args:
            session_id: The ID of the session.
        Returns:
            The completed stages mapped to the result of the verification of
            that auth-type.
        """
        rows = await self.db_pool.simple_select_list(
            table="ui_auth_sessions_credentials",
            keyvalues={"session_id": session_id},
            retcols=("stage_type", "result"),
            desc="get_completed_ui_auth_stages",
        )
        return {row["stage_type"]: db_to_json(row["result"]) for row in rows}

    async def set_ui_auth_clientdict(
        self, session_id: str, clientdict: JsonDict
    ) -> None:
        """
        Store an updated clientdict for a given session ID.

        Args:
            session_id: The ID of this session as returned from check_auth
            clientdict:
                The dictionary from the client root level, not the 'auth' key.
        """
        # The clientdict gets stored as JSON.
        clientdict_json = json_encoder.encode(clientdict)

        await self.db_pool.simple_update_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            updatevalues={"clientdict": clientdict_json},
            desc="set_ui_auth_client_dict",
        )

    async def set_ui_auth_session_data(
        self, session_id: str, key: str, value: Any
    ) -> None:
        """
        Store a key-value pair into the sessions data associated with this
        request. This data is stored server-side and cannot be modified by
        the client.

        Args:
            session_id: The ID of this session as returned from check_auth
            key: The key to store the data under
            value: The data to store. Must be JSON-serialisable.
        Raises:
            StoreError if the session cannot be found.
        """
        await self.db_pool.runInteraction(
            "set_ui_auth_session_data",
            self._set_ui_auth_session_data_txn,
            session_id,
            key,
            value,
        )

    def _set_ui_auth_session_data_txn(
        self, txn: LoggingTransaction, session_id: str, key: str, value: Any
    ) -> None:
        """Read the session's serverdict, set `key` to `value` and write it
        back, all within the given transaction."""
        # Get the current value.
        result = cast(
            Dict[str, Any],
            self.db_pool.simple_select_one_txn(
                txn,
                table="ui_auth_sessions",
                keyvalues={"session_id": session_id},
                retcols=("serverdict",),
            ),
        )

        # Update it and add it back to the database.
        serverdict = db_to_json(result["serverdict"])
        serverdict[key] = value

        self.db_pool.simple_update_one_txn(
            txn,
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            updatevalues={"serverdict": json_encoder.encode(serverdict)},
        )

    async def get_ui_auth_session_data(
        self, session_id: str, key: str, default: Optional[Any] = None
    ) -> Any:
        """
        Retrieve data stored with set_ui_auth_session_data.

        Args:
            session_id: The ID of this session as returned from check_auth
            key: The key the data was stored under
            default: Value to return if the key has not been set
        Returns:
            The value stored under `key` in the session's server-side data,
            or `default` if that key has not been set.
        Raises:
            StoreError if the session cannot be found.
        """
        result = await self.db_pool.simple_select_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            retcols=("serverdict",),
            desc="get_ui_auth_session_data",
        )

        serverdict = db_to_json(result["serverdict"])

        return serverdict.get(key, default)

    async def add_user_agent_ip_to_ui_auth_session(
        self,
        session_id: str,
        user_agent: str,
        ip: str,
    ) -> None:
        """Add the given user agent / IP to the tracking table.

        Recording the same (session, user agent, IP) triple again is a no-op,
        since all three columns form the upsert key.
        """
        await self.db_pool.simple_upsert(
            table="ui_auth_sessions_ips",
            keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
            values={},
            desc="add_user_agent_ip_to_ui_auth_session",
        )

    async def get_user_agents_ips_to_ui_auth_session(
        self,
        session_id: str,
    ) -> List[Tuple[str, str]]:
        """Get the given user agents / IPs used during the ui auth process

        Args:
            session_id: The ID of the session.
        Returns:
            List of (user_agent, ip) pairs
        """
        rows = await self.db_pool.simple_select_list(
            table="ui_auth_sessions_ips",
            keyvalues={"session_id": session_id},
            retcols=("user_agent", "ip"),
            desc="get_user_agents_ips_to_ui_auth_session",
        )
        return [(row["user_agent"], row["ip"]) for row in rows]

    async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
        """
        Remove sessions which were created at or before the expiration time,
        together with their recorded user agents/IPs and completed stages.

        If a removed session authenticated with a registration token but never
        completed registration, that token's `pending` count is decremented.

        Args:
            expiration_time: Sessions whose creation time is at or before this
                value are deleted. This is an epoch time in milliseconds.
        """
        await self.db_pool.runInteraction(
            "delete_old_ui_auth_sessions",
            self._delete_old_ui_auth_sessions_txn,
            expiration_time,
        )

    def _delete_old_ui_auth_sessions_txn(
        self, txn: LoggingTransaction, expiration_time: int
    ) -> None:
        """Transaction body for delete_old_ui_auth_sessions."""
        # Get the expired sessions.
        sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
        txn.execute(sql, [expiration_time])
        session_ids = [r[0] for r in txn.fetchall()]

        # Nothing has expired, so there is nothing to clean up.
        if not session_ids:
            return

        # Delete the corresponding IP/user agents.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions_ips",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )

        # If a registration token was used, decrement the pending counter
        # before deleting the session.
        rows = self.db_pool.simple_select_many_txn(
            txn,
            table="ui_auth_sessions_credentials",
            column="session_id",
            iterable=session_ids,
            keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
            retcols=["result"],
        )

        # Get the tokens used and how much pending needs to be decremented by.
        token_counts: Dict[str, int] = {}
        for r in rows:
            # If registration was successfully completed, the result of the
            # registration token stage for that session will be True.
            # If a token was used to authenticate, but registration was
            # never completed, the result will be the token used.
            token = db_to_json(r["result"])
            if isinstance(token, str):
                token_counts[token] = token_counts.get(token, 0) + 1

        # Update the `pending` counters.
        if token_counts:
            token_rows = self.db_pool.simple_select_many_txn(
                txn,
                table="registration_tokens",
                column="token",
                iterable=list(token_counts.keys()),
                keyvalues={},
                retcols=["token", "pending"],
            )
            for token_row in token_rows:
                token = token_row["token"]
                new_pending = token_row["pending"] - token_counts[token]
                self.db_pool.simple_update_one_txn(
                    txn,
                    table="registration_tokens",
                    keyvalues={"token": token},
                    updatevalues={"pending": new_pending},
                )

        # Delete the corresponding completed credentials.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions_credentials",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )

        # Finally, delete the sessions.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )
395
396
class UIAuthStore(UIAuthWorkerStore):
    """UI auth store; currently identical to UIAuthWorkerStore, adding no
    extra behaviour of its own."""
399