1import hashlib
2import hmac
3import logging
4import time
5import cgi
6import six
7
8from six.moves.urllib.parse import quote, parse_qs
9from six.moves.http_cookies import SimpleCookie
10
11from saml2 import BINDING_HTTP_ARTIFACT
12from saml2 import BINDING_HTTP_REDIRECT
13from saml2 import BINDING_HTTP_POST
14from saml2 import BINDING_URI
15from saml2 import BINDING_SOAP
16from saml2 import SAMLError
17from saml2 import time_util
18
19__author__ = 'rohe0002'
20
21logger = logging.getLogger(__name__)
22
23
class Response(object):
    """Base WSGI response.

    Class attributes hold the defaults (status line, ``%`` body template,
    content type, optional mako template and lookup). Each can be overridden
    per instance through keyword arguments to ``__init__``. Calling the
    instance as a WSGI application sends the status and headers and returns
    the rendered body.
    """
    _template = None
    _status = '200 OK'
    _content_type = 'text/html'
    _mako_template = None
    _mako_lookup = None

    def __init__(self, message=None, **kwargs):
        """
        :param message: Body payload, formatted into the template if any
        :param kwargs: Optional overrides -- ``status``, ``response``
            (callable producing the body), ``template``, ``mako_template``,
            ``template_lookup``, ``headers`` (list of ``(name, value)``
            tuples) and ``content`` (Content-type value added when
            ``headers`` carries no Content-type header)
        """
        self.status = kwargs.get('status', self._status)
        self.response = kwargs.get('response', self._response)
        self.template = kwargs.get('template', self._template)
        self.mako_template = kwargs.get('mako_template', self._mako_template)
        self.mako_lookup = kwargs.get('template_lookup', self._mako_lookup)

        self.message = message

        # Work on a copy. The Content-type added below, add_header() and the
        # redirect subclasses all modify self.headers, and those changes
        # must not leak into the list the caller passed in.
        self.headers = list(kwargs.get('headers') or [])
        if not any(header[0].lower() == 'content-type'
                   for header in self.headers):
            self.headers.append(
                ('Content-type', kwargs.get('content', self._content_type)))

    def __call__(self, environ, start_response, **kwargs):
        """WSGI entry point: send status and headers, then return the body.

        Without a message, the request URL rebuilt from *environ* is
        rendered instead.
        """
        try:
            start_response(self.status, self.headers)
        except TypeError:
            # NOTE(review): a TypeError raised by start_response is ignored
            # on purpose here; presumably for callers that pass a placeholder
            # instead of a real start_response -- confirm.
            pass
        return self.response(self.message or geturl(environ), **kwargs)

    def _response(self, message="", **argv):
        """Render *message* and return the body.

        A ``%``-template takes precedence over a mako template. The mako
        template receives *argv* plus ``message``. Text is returned UTF-8
        encoded in a one-element list, bytes are wrapped as-is, and anything
        else is returned unchanged.
        """
        if self.template:
            message = self.template % message
        elif self.mako_lookup and self.mako_template:
            argv["message"] = message
            mte = self.mako_lookup.get_template(self.mako_template)
            message = mte.render(**argv)

        # Test for bytes first. On Python 2 ``str`` is both a byte string and
        # a string type, and calling .encode() on it would implicitly
        # ASCII-decode it first, which fails on non-ASCII bodies.
        if isinstance(message, six.binary_type):
            return [message]
        elif isinstance(message, six.string_types):
            return [message.encode('utf-8')]
        else:
            return message

    def add_header(self, ava):
        """
        Does *NOT* replace a header of the same type, just adds a new
        :param ava: (type, value) tuple
        """
        self.headers.append(ava)

    def reply(self, **kwargs):
        """Render and return the body without calling start_response."""
        return self.response(self.message, **kwargs)
80
81
class Created(Response):
    """201 Created response; rendered exactly like the Response base class."""
    _status = "201 Created"
84
85
class Redirect(Response):
    """302 Found response redirecting the client to ``message``.

    The target URL is sent in a ``location`` header and repeated in a small
    HTML body (``_template``) for clients that do not follow redirects.
    """
    _template = '<html>\n<head><title>Redirecting to %s</title></head>\n' \
                '<body>\nYou are being redirected to <a href="%s">%s</a>\n' \
                '</body>\n</html>'
    _status = '302 Found'

    def __call__(self, environ, start_response, **kwargs):
        from xml.sax.saxutils import escape

        location = self.message
        # Replace rather than append. Calling the response more than once, or
        # after a location header was added by hand, must never send several
        # Location headers.
        self.headers = [header for header in self.headers
                        if header[0].lower() != 'location']
        self.headers.append(('location', location))
        start_response(self.status, self.headers)
        # The URL is interpolated into HTML, including an attribute value.
        # Escape it so a crafted URL cannot inject markup into the page.
        if isinstance(location, six.string_types):
            location = escape(location, {'"': '&quot;'})
        return self.response((location, location, location))
97
98
class SeeOther(Response):
    """303 See Other response.

    The target is ``message`` when given. Otherwise it is taken from a
    ``location`` header already present in ``headers``. The target is also
    shown in a small HTML body (``_template``).
    """
    _template = '<html>\n<head><title>Redirecting to %s</title></head>\n' \
                '<body>\nYou are being redirected to <a href="%s">%s</a>\n' \
                '</body>\n</html>'
    _status = '303 See Other'

    def __call__(self, environ, start_response, **kwargs):
        from xml.sax.saxutils import escape

        if self.message:
            location = self.message
            # Replace any existing location header instead of adding another.
            self.headers = [header for header in self.headers
                            if header[0].lower() != 'location']
            self.headers.append(('location', location))
        else:
            # No explicit target: reuse a location header supplied via
            # ``headers=`` or add_header(). Header names are
            # case-insensitive, so "Location" must match too.
            location = ""
            for param, item in self.headers:
                if param.lower() == "location":
                    location = item
                    break
        start_response(self.status, self.headers)
        # The URL is interpolated into HTML, including an attribute value.
        # Escape it so a crafted URL cannot inject markup into the page.
        if isinstance(location, six.string_types):
            location = escape(location, {'"': '&quot;'})
        return self.response((location, location, location))
117
118
class Forbidden(Response):
    """403 Forbidden response with the message embedded in a short HTML page."""
    # NOTE(review): the message is inserted without HTML escaping; confirm
    # callers never pass attacker-controlled text.
    _template = "<html>Not allowed to mess with: '%s'</html>"
    _status = '403 Forbidden'
122
123
class BadRequest(Response):
    """400 Bad Request response; the message is wrapped in <html> tags."""
    # NOTE(review): the message is not HTML-escaped.
    _template = "<html>%s</html>"
    _status = "400 Bad Request"
127
128
class Unauthorized(Response):
    """401 Unauthorized response; the message is wrapped in <html> tags."""
    # NOTE(review): the message is not HTML-escaped.
    _template = "<html>%s</html>"
    _status = "401 Unauthorized"
132
133
class NotFound(Response):
    """404 response; the body is rendered like the Response base class."""
    _status = '404 NOT FOUND'
136
137
class NotAcceptable(Response):
    """406 Not Acceptable response; body rendered like the base class."""
    _status = '406 Not Acceptable'
140
141
class ServiceError(Response):
    """500 response; body rendered like the base class."""
    # NOTE(review): the usual reason phrase is "Internal Server Error"; kept
    # unchanged in case callers compare against this exact status string.
    _status = '500 Internal Service Error'
144
145
class NotImplemented(Response):
    """501 Not Implemented response."""
    _status = "501 Not Implemented"
    # override template since we need an environment variable
    # NOTE(review): this sets ``template``, not ``_template``.
    # Response.__init__ assigns ``self.template`` from the ``template`` kwarg
    # or ``_template`` (None here), so this class attribute is shadowed on
    # every instance and never used for rendering. It also has two %s slots,
    # so it would need a 2-tuple message. The class name shadows the
    # ``NotImplemented`` builtin for star-importers. Left as-is to keep the
    # public interface.
    template = ('The request method %s is not implemented '
                'for this server.\r\n%s')
151
152
class BadGateway(Response):
    """502 Bad Gateway response; body rendered like the base class."""
    _status = "502 Bad Gateway"
155
156
class HttpParameters(object):
    """Signature parameters of the HTTP-Redirect / POST-SimpleSign bindings.

    These bindings carry the signature outside the XML message (unlike the
    POST binding), so it is picked out of the request parameters here. The
    argument is presumably a mapping of name -> list of values (parse_qs
    style); element [0] of each entry is kept.
    """
    # RelayState and the SAML message itself are stored elsewhere.
    signature = None
    sigalg = None

    def __init__(self, dict):
        wanted = (("signature", "Signature"), ("sigalg", "SigAlg"))
        for attr_name, param_name in wanted:
            try:
                setattr(self, attr_name, dict[param_name][0])
            except KeyError:
                # Stop at the first missing parameter: without "Signature"
                # nothing is set; with "Signature" but no "SigAlg" only the
                # signature is kept. Unset attributes stay None.
                break
170
171
def extract(environ, empty=False, err=False):
    """Extract the form fields of a request and return them as a dict.

    Fields are parsed by :func:`cgi.parse` from ``environ['wsgi.input']``
    and the WSGI environ. A field that occurs once maps to its single value;
    a field that occurs several times keeps its list of values.

    :param environ: WSGI environ
    :param empty: Passed to ``cgi.parse`` as ``keep_blank_values``: keep
        fields whose value is blank (default: False)
    :param err: Passed to ``cgi.parse`` as ``strict_parsing``: raise on
        malformed fields instead of ignoring them (default: False)
    :return: dict mapping field name to a value or a list of values
    """
    formdata = cgi.parse(environ['wsgi.input'], environ, empty, err)
    # Collapse single-entry lists to their only value.
    return {key: value[0] if len(value) == 1 else value
            for key, value in formdata.items()}
185
186
def geturl(environ, query=True, path=True, use_server_name=False):
    """Rebuilds a request URL (from PEP 333).

    The host comes from HTTP_HOST unless *use_server_name* is set. In that
    case SERVER_NAME is used, followed by SERVER_PORT unless it is the
    scheme's default port (443 for https, 80 otherwise). When HTTP_HOST is
    missing or empty, the SERVER_NAME/SERVER_PORT form is used as well, as
    PEP 333 prescribes.

    :param environ: WSGI environ
    :param query: Is QUERY_STRING included in URI (default: True)
    :param path: Is path included in URI (default: True)
    :param use_server_name: If SERVER_NAME/_PORT should be used instead of
        HTTP_HOST
    :return: the reconstructed URL
    """
    scheme = environ['wsgi.url_scheme']
    url = [scheme + '://']

    host = None if use_server_name else environ.get('HTTP_HOST')
    if host:
        url.append(host)
    else:
        url.append(environ['SERVER_NAME'])
        port = environ['SERVER_PORT']
        default_port = '443' if scheme == 'https' else '80'
        if port != default_port:
            url.append(':' + port)

    if path:
        url.append(getpath(environ))
    if query and environ.get('QUERY_STRING'):
        url.append('?' + environ['QUERY_STRING'])
    return ''.join(url)
214
215
def getpath(environ):
    """Return the URL-quoted request path: SCRIPT_NAME then PATH_INFO.

    Missing keys count as empty strings.
    """
    script_name = quote(environ.get('SCRIPT_NAME', ''))
    path_info = quote(environ.get('PATH_INFO', ''))
    return script_name + path_info
220
221
def get_post(environ):
    """Read and return the request body from ``environ['wsgi.input']``.

    Exactly CONTENT_LENGTH bytes are read. A missing, empty, None,
    non-numeric or negative CONTENT_LENGTH counts as 0, so nothing is read.
    A negative size passed to ``read()`` would otherwise read until EOF.

    :param environ: WSGI environ
    :return: whatever ``wsgi.input.read`` returns (bytes under PEP 3333)
    """
    # The environment variable CONTENT_LENGTH may be empty or missing.
    try:
        request_body_size = int(environ.get('CONTENT_LENGTH') or 0)
    except (TypeError, ValueError):
        request_body_size = 0
    request_body_size = max(request_body_size, 0)

    # When the method is POST the query string is sent in the HTTP request
    # body, which the WSGI server exposes as the file-like wsgi.input.
    return environ['wsgi.input'].read(request_body_size)
233
234
def get_response(environ, start_response):
    """Return the raw request parameters of a GET or POST request.

    GET yields QUERY_STRING (None if absent) and POST yields the request
    body. Any other method is answered with a 400 BadRequest, and that
    response's body is returned instead.
    """
    method = environ.get("REQUEST_METHOD")
    if method == "GET":
        return environ.get("QUERY_STRING")
    if method == "POST":
        return get_post(environ)

    bad_request = BadRequest("Unsupported method")
    return bad_request(environ, start_response)
245
246
def unpack_redirect(environ):
    """Return the query-string parameters as a flat dict.

    Only the first value of a repeated parameter is kept. Returns None when
    the environ has no QUERY_STRING key at all.
    """
    if "QUERY_STRING" not in environ:
        return None
    parsed = parse_qs(environ["QUERY_STRING"])
    return {key: values[0] for key, values in parsed.items()}
253
254
def unpack_post(environ):
    """Return the urlencoded form fields of a POST request as a flat dict.

    The body is read with get_post(). Only the first value of a repeated
    field is kept.
    """
    body = get_post(environ)
    if isinstance(body, bytes) and not isinstance(body, str):
        # Python 3: wsgi.input yields bytes. Decode so keys and values are
        # str like those from unpack_redirect(). On Python 2 ``str`` already
        # is the byte string and is left alone.
        body = body.decode('utf-8')
    # parse_qs returns a dict; iterate .items() to get (name, values) pairs.
    return {key: values[0] for key, values in parse_qs(body).items()}
257
258
def unpack_soap(environ):
    """Wrap the raw request body for the SOAP binding.

    :param environ: WSGI environ
    :return: ``{"SAMLRequest": <raw body>, "RelayState": ""}``, or None when
        the body cannot be read
    """
    try:
        query = get_post(environ)
    except Exception:
        # Best effort: callers get None, but leave a trace of why instead of
        # swallowing the error silently.
        logger.exception("Failed to read SOAP request body")
        return None
    return {"SAMLRequest": query, "RelayState": ""}
265
266
def unpack_artifact(environ):
    """Return the parameters of an HTTP-Artifact request.

    GET requests are read from the query string and POST requests from the
    form body. Any other method yields None. REQUEST_METHOD must be present.
    """
    method = environ["REQUEST_METHOD"]
    if method == "GET":
        return unpack_redirect(environ)
    if method == "POST":
        return unpack_post(environ)
    return None
275
276
def unpack_any(environ):
    """Unpack a SAML request and detect which binding it was sent with.

    GET: the query string is parsed. An ``ID`` parameter means the URI
    binding, ``SAMLart`` means HTTP-Artifact, and anything else means
    HTTP-Redirect.

    Other methods: a SOAP media type (also assumed when CONTENT_TYPE is
    absent) means the SOAP binding. Otherwise the body is parsed as a form,
    where ``SAMLart`` means HTTP-Artifact and anything else means HTTP-POST.

    :param environ: WSGI environ
    :return: tuple ``(parameters, binding)``
    """
    if environ['REQUEST_METHOD'].upper() == 'GET':
        # Could be either redirect or artifact. A GET without QUERY_STRING
        # has no parameters: use an empty dict instead of failing on None.
        _dict = unpack_redirect(environ) or {}
        if "ID" in _dict:
            binding = BINDING_URI
        elif "SAMLart" in _dict:
            binding = BINDING_HTTP_ARTIFACT
        else:
            binding = BINDING_HTTP_REDIRECT
    else:
        content_type = environ.get('CONTENT_TYPE', 'application/soap+xml')
        # Compare only the media type. Clients commonly append parameters
        # such as "; charset=utf-8", which must not turn SOAP into a form
        # post.
        media_type = (content_type or '').split(';', 1)[0].strip().lower()
        if media_type != 'application/soap+xml':
            # normal post
            _dict = unpack_post(environ)
            if "SAMLart" in _dict:
                binding = BINDING_HTTP_ARTIFACT
            else:
                binding = BINDING_HTTP_POST
        else:
            _dict = unpack_soap(environ)
            binding = BINDING_SOAP

    return _dict, binding
301
302
def _expiration(timeout, time_format=None):
    """Return an expiry time computed by time_util.

    :param timeout: The string "now" for the current instant; any other
        value is passed to time_util.in_a_while() as a number of minutes
    :param time_format: Output format handed through to time_util
    """
    if timeout != "now":
        # validity time should match lifetime of assertions
        return time_util.in_a_while(minutes=timeout, format=time_format)
    return time_util.instant(time_format)
309
310
def cookie_signature(seed, *parts):
    """Generates a cookie signature.

    Computes an HMAC-SHA1 keyed with *seed* over the concatenation of the
    non-empty *parts* and returns it as a hex string. Text arguments (seed
    or parts) are UTF-8 encoded first, because ``hmac`` only accepts bytes
    on Python 3. Byte strings are used unchanged.
    """
    # unicode on Python 2, str on Python 3, without needing six here.
    text_type = type(u"")

    def _as_bytes(value):
        return value.encode('utf-8') if isinstance(value, text_type) else value

    sha1 = hmac.new(_as_bytes(seed), digestmod=hashlib.sha1)
    for part in parts:
        if part:
            sha1.update(_as_bytes(part))
    return sha1.hexdigest()
318
319
def make_cookie(name, load, seed, expire=0, domain="", path="",
                timestamp=""):
    """
    Create and return a signed cookie header.

    The cookie value is ``load|timestamp|signature``, where the signature
    is cookie_signature(seed, load, timestamp).

    :param name: Cookie name
    :param load: Cookie load
    :param seed: A seed for the HMAC function
    :param expire: Number of minutes before this cookie goes stale; when
        falsy no ``expires`` attribute is set
    :param domain: The domain of the cookie (omitted when empty)
    :param path: The path specification for the cookie (omitted when empty)
    :param timestamp: Timestamp to embed; the current time when empty
    :return: A tuple to be added to headers
    """
    if not timestamp:
        # NOTE(review): mktime() interprets its argument as *local* time, so
        # feeding it gmtime() yields a value offset by the local UTC offset
        # rather than time.time(). Left unchanged since stored cookies and
        # callers may rely on it -- confirm before changing.
        timestamp = str(int(time.mktime(time.gmtime())))
    signature = cookie_signature(seed, load, timestamp)

    cookie = SimpleCookie()
    cookie[name] = "|".join([load, timestamp, signature])
    morsel = cookie[name]
    if path:
        morsel["path"] = path
    if domain:
        morsel["domain"] = domain
    if expire:
        morsel["expires"] = _expiration(expire, "%a, %d-%b-%Y %H:%M:%S GMT")

    header_line = cookie.output()
    return tuple(header_line.split(": ", 1))
347
348
def parse_cookie(name, seed, kaka):
    """Parses and verifies a cookie value

    :param name: Name of the cookie to look for
    :param seed: A seed used for the HMAC signature
    :param kaka: The cookie
    :return: A tuple consisting of (payload, timestamp), or None when the
        cookie is absent or not of the form ``payload|timestamp|signature``
    :raises SAMLError: if the signature does not match
    """
    if not kaka:
        return None

    morsel = SimpleCookie(kaka).get(name)
    if not morsel:
        return None

    parts = morsel.value.split("|")
    if len(parts) != 3:
        return None
    payload, timestamp, signature = parts

    # Verify the cookie signature. Compare in constant time so the check does
    # not leak how much of a forged signature was correct.
    expected = cookie_signature(seed, payload, timestamp)
    try:
        valid = hmac.compare_digest(expected, signature)
    except TypeError:
        # compare_digest rejects non-ASCII text; such a value can never
        # equal a hex digest.
        valid = False
    if not valid:
        raise SAMLError("Invalid cookie signature")

    return payload.strip(), timestamp
377
378
def cookie_parts(name, kaka):
    """Return the ``|``-separated fields of cookie *name* found in *kaka*.

    No signature check is performed. Returns None when the cookie is absent.
    """
    morsel = SimpleCookie(kaka).get(name)
    return morsel.value.split("|") if morsel else None
386