1# Copyright 2020 The Matrix.org Foundation C.I.C.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15import urllib.parse
16from typing import TYPE_CHECKING, Dict, List, Optional
17from xml.etree import ElementTree as ET
18
19import attr
20
21from twisted.web.client import PartialDownloadError
22
23from synapse.api.errors import HttpResponseException
24from synapse.handlers.sso import MappingException, UserAttributes
25from synapse.http.site import SynapseRequest
26from synapse.types import UserID, map_username_to_mxid_localpart
27
28if TYPE_CHECKING:
29    from synapse.server import HomeServer
30
31logger = logging.getLogger(__name__)
32
33
class CasError(Exception):
    """Raised when a CAS ticket cannot be validated or its response parsed.

    Attributes:
        error: a short error code (e.g. "server_error", "no_user").
        error_description: optional human-readable detail about the failure.
    """

    def __init__(self, error: str, error_description: Optional[str] = None):
        self.error = error
        self.error_description = error_description

    def __str__(self) -> str:
        # Only append the description when one was actually supplied (an empty
        # string is treated the same as no description).
        detail = self.error_description
        return f"{self.error}: {detail}" if detail else self.error
45
46
@attr.s(auto_attribs=True, frozen=True, slots=True)
class CasResponse:
    """The parsed result of a successful CAS ticket validation."""

    # The value of the <user> element in the CAS response.
    username: str
    # Attribute name (namespace stripped) -> every value seen for it, in
    # document order. A value is None when the element had no text.
    attributes: Dict[str, List[Optional[str]]]
51
52
class CasHandler:
    """
    Utility class to handle the response from a CAS SSO service.

    Registers itself with the SSO handler as the "cas" identity provider, builds
    the URL of the CAS server's login page, and validates the tickets the CAS
    server hands back before completing a login or user-interactive auth.

    Args:
        hs: the homeserver, used to look up the CAS configuration, the datastore,
            the proxied HTTP client and the auth/registration/SSO handlers.
    """

    def __init__(self, hs: "HomeServer"):
        self.hs = hs
        self._hostname = hs.hostname
        self._store = hs.get_datastore()
        self._auth_handler = hs.get_auth_handler()
        self._registration_handler = hs.get_registration_handler()

        self._cas_server_url = hs.config.cas.cas_server_url
        self._cas_service_url = hs.config.cas.cas_service_url
        self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute
        self._cas_required_attributes = hs.config.cas.cas_required_attributes

        self._http_client = hs.get_proxied_http_client()

        # identifier for the external_ids table
        self.idp_id = "cas"

        # user-facing name of this auth provider
        self.idp_name = "CAS"

        # we do not currently support brands/icons for CAS auth, but this is required by
        # the SsoIdentityProvider protocol type.
        self.idp_icon = None
        self.idp_brand = None

        self._sso_handler = hs.get_sso_handler()

        self._sso_handler.register_identity_provider(self)

    def _build_service_param(self, args: Dict[str, str]) -> str:
        """
        Generates a value to use as the "service" parameter when redirecting or
        querying the CAS service.

        Args:
            args: Additional arguments to include in the final redirect URL.

        Returns:
            The URL to use as a "service" parameter: the configured service URL
            with `args` appended as a URL-encoded query string.
        """
        return "%s?%s" % (
            self._cas_service_url,
            urllib.parse.urlencode(args),
        )

    async def _validate_ticket(
        self, ticket: str, service_args: Dict[str, str]
    ) -> CasResponse:
        """
        Validate a CAS ticket with the server, and return the parsed response.

        The ticket is sent to the CAS server's `/proxyValidate` endpoint together
        with the same "service" URL that was used when redirecting to the login
        page.

        Args:
            ticket: The CAS ticket from the client.
            service_args: Additional arguments to include in the service URL.
                Should be the same as those passed to `handle_redirect_request`.

        Raises:
            CasError: If the CAS server returns an HTTP error, or there's an
                error parsing the CAS response.

        Returns:
            The parsed CAS response.
        """
        uri = self._cas_server_url + "/proxyValidate"
        args = {
            "ticket": ticket,
            "service": self._build_service_param(service_args),
        }
        try:
            body = await self._http_client.get_raw(uri, args)
        except PartialDownloadError as pde:
            # Twisted raises this error if the connection is closed,
            # even if that's being used old-http style to signal end-of-data
            body = pde.response
        except HttpResponseException as e:
            description = (
                'Authorization server responded with a "{status}" error '
                "while validating the CAS ticket."
            ).format(status=e.code)
            raise CasError("server_error", description) from e

        return self._parse_cas_response(body)

    def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
        """
        Retrieve the user and other parameters from the CAS response.

        Args:
            cas_response_body: The response from the CAS query.

        Raises:
            CasError: If the response is not well-formed XML, is not a
                successful serviceResponse, or contains no user.

        Returns:
            The parsed CAS response.
        """

        # The body comes from a remote server; a malformed document (e.g. an HTML
        # error page) must surface as a CasError so callers can render an error
        # page, rather than escaping as an unhandled ParseError.
        try:
            root = ET.fromstring(cas_response_body)
        except ET.ParseError as e:
            raise CasError(
                "invalid_response", "CAS response is not valid XML"
            ) from e

        # Ensure the response is valid.
        if not root.tag.endswith("serviceResponse"):
            raise CasError(
                "missing_service_response",
                "root of CAS response is not serviceResponse",
            )

        # Guard against an empty serviceResponse before indexing its first child.
        if len(root) == 0 or not root[0].tag.endswith("authenticationSuccess"):
            raise CasError("unsucessful_response", "Unsuccessful CAS response")

        # Iterate through the nodes and pull out the user and any extra attributes.
        user = None
        attributes: Dict[str, List[Optional[str]]] = {}
        for child in root[0]:
            if child.tag.endswith("user"):
                user = child.text
            if child.tag.endswith("attributes"):
                for attribute in child:
                    # ElementTree library expands the namespace in
                    # attribute tags to the full URL of the namespace.
                    # We don't care about namespace here and it will always
                    # be encased in curly braces, so we remove them.
                    tag = attribute.tag
                    if "}" in tag:
                        tag = tag.split("}")[1]
                    attributes.setdefault(tag, []).append(attribute.text)

        # Ensure a user was found.
        if user is None:
            raise CasError("no_user", "CAS response does not contain user")

        return CasResponse(user, attributes)

    async def handle_redirect_request(
        self,
        request: SynapseRequest,
        client_redirect_url: Optional[bytes],
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        """Generates a URL for the CAS server where the client should be redirected.

        Args:
            request: the incoming HTTP request (not used by the CAS provider).
            client_redirect_url: the URL that we should redirect the
                client to after login (or None for UI Auth).
            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                None if this is a login).

        Returns:
            URL to redirect to
        """

        # The UI Auth session id (or the client's redirect URL) is carried in the
        # service URL so that it comes back to us alongside the ticket.
        if ui_auth_session_id:
            service_args = {"session": ui_auth_session_id}
        else:
            assert client_redirect_url
            service_args = {"redirectUrl": client_redirect_url.decode("utf8")}

        args = urllib.parse.urlencode(
            {"service": self._build_service_param(service_args)}
        )

        return "%s/login?%s" % (self._cas_server_url, args)

    async def handle_ticket(
        self,
        request: SynapseRequest,
        ticket: str,
        client_redirect_url: Optional[str],
        session: Optional[str],
    ) -> None:
        """
        Called once the user has successfully authenticated with the SSO.
        Validates a CAS ticket sent by the client and completes the auth process.

        If the user interactive authentication session is provided, marks the
        UI Auth session as complete, then returns an HTML page notifying the
        user they are done.

        Otherwise, this registers the user if necessary, and then returns a
        redirect (with a login token) to the client.

        Args:
            request: the incoming request from the browser. We'll
                respond to it with a redirect or an HTML page.

            ticket: The CAS ticket provided by the client.

            client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
                This should be the same as the redirectUrl from the original `/login/sso/redirect` request.

            session: The session parameter from the `/cas/ticket` HTTP request, if given.
                This should be the UI Auth session id.
        """
        # Rebuild the service arguments so the "service" URL sent for validation
        # matches the one used when redirecting to the CAS login page.
        args: Dict[str, str] = {}
        if client_redirect_url:
            args["redirectUrl"] = client_redirect_url
        if session:
            args["session"] = session

        try:
            cas_response = await self._validate_ticket(ticket, args)
        except CasError as e:
            logger.exception("Could not validate ticket")
            self._sso_handler.render_error(request, e.error, e.error_description, 401)
            return

        await self._handle_cas_response(
            request, cas_response, client_redirect_url, session
        )

    async def _handle_cas_response(
        self,
        request: SynapseRequest,
        cas_response: CasResponse,
        client_redirect_url: Optional[str],
        session: Optional[str],
    ) -> None:
        """Handle a CAS response to a ticket request.

        Assumes that the response has been validated. Maps the user onto an MXID,
        registering them if necessary, and returns a response to the browser.

        Args:
            request: the incoming request from the browser. We'll respond to it with an
                HTML page or a redirect

            cas_response: The parsed CAS response.

            client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
                This should be the same as the redirectUrl from the original `/login/sso/redirect` request.

            session: The session parameter from the `/cas/ticket` HTTP request, if given.
                This should be the UI Auth session id.
        """

        # first check if we're doing a UIA
        if session:
            return await self._sso_handler.complete_sso_ui_auth_request(
                self.idp_id,
                cas_response.username,
                session,
                request,
            )

        # otherwise, we're handling a login request.

        # Ensure that the attributes of the logged in user meet the required
        # attributes. If they don't, check_required_attributes has already
        # responded to the request, so there is nothing more to do.
        if not self._sso_handler.check_required_attributes(
            request, cas_response.attributes, self._cas_required_attributes
        ):
            return

        # Call the mapper to register/login the user

        # If this not a UI auth request then there must be a redirect URL.
        assert client_redirect_url is not None

        try:
            await self._complete_cas_login(cas_response, request, client_redirect_url)
        except MappingException as e:
            logger.exception("Could not map user")
            self._sso_handler.render_error(request, "mapping_error", str(e))

    async def _complete_cas_login(
        self,
        cas_response: CasResponse,
        request: SynapseRequest,
        client_redirect_url: str,
    ) -> None:
        """
        Given a CAS response, complete the login flow

        Retrieves the remote user ID, registers the user if necessary, and serves
        a redirect back to the client with a login-token.

        Args:
            cas_response: The parsed CAS response.
            request: The request to respond to
            client_redirect_url: The redirect URL passed in by the client.

        Raises:
            MappingException if there was a problem mapping the response to a user.
            RedirectException: some mapping providers may raise this if they need
                to redirect to an interstitial page.
        """
        # Note that CAS does not support a mapping provider, so the logic is hard-coded.
        localpart = map_username_to_mxid_localpart(cas_response.username)

        async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
            """
            Map from CAS attributes to user attributes.
            """
            # Due to the grandfathering logic matching any previously registered
            # mxids it isn't expected for there to be any failures.
            if failures:
                raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")

            # Arbitrarily use the first attribute found.
            display_name = cas_response.attributes.get(
                self._cas_displayname_attribute, [None]
            )[0]

            return UserAttributes(localpart=localpart, display_name=display_name)

        async def grandfather_existing_users() -> Optional[str]:
            # Since CAS did not always use the user_external_ids table, always
            # attempt to map to existing users.
            user_id = UserID(localpart, self._hostname).to_string()

            logger.debug(
                "Looking for existing account based on mapped %s",
                user_id,
            )

            users = await self._store.get_users_by_id_case_insensitive(user_id)
            if users:
                # Take the first matching registered user ID.
                registered_user_id = next(iter(users))
                logger.info("Grandfathering mapping to %s", registered_user_id)
                return registered_user_id

            return None

        await self._sso_handler.complete_sso_login_request(
            self.idp_id,
            cas_response.username,
            request,
            client_redirect_url,
            cas_response_to_user_attributes,
            grandfather_existing_users,
        )
391