1
2from hyper import HTTP20Connection
3from hyper.tls import init_context
4
5import jwt
6
7# For creating and comparing the time for the JWT token
8import time
9
10
# Fallback token lifetime, used when TokenCredentials is constructed without
# a token_lifetime. It is added to time.time() values, so it is in seconds.
DEFAULT_TOKEN_LIFETIME = 3600
# Fallback JWT algorithm name, used when TokenCredentials is constructed
# without an encryption_algorithm.
DEFAULT_TOKEN_ENCRYPTION_ALGORITHM = 'ES256'
13
14
15# Abstract Base class. This should not be instantiated directly.
class Credentials(object):
    """Common base for the credential types in this module.

    Keeps an optional SSL context and knows how to open an HTTP/2
    connection with it. Meant to be subclassed rather than instantiated
    directly.
    """

    def __init__(self, ssl_context=None):
        """Remember ``ssl_context`` (``None`` is an accepted value)."""
        self.__ssl_context = ssl_context

    def create_connection(self, server, port, proto):
        """Open an ``HTTP20Connection`` to ``server`` on ``port``.

        ``proto`` is forwarded as ``force_proto``; any falsy value falls
        back to ``'h2'``. The stored SSL context is passed through as-is,
        including when it is ``None``.
        """
        forced_proto = proto or 'h2'
        context = self.__ssl_context
        return HTTP20Connection(
            server,
            port,
            ssl_context=context,
            force_proto=forced_proto,
        )

    def get_authorization_header(self, topic):
        """Return ``None``; the base class supplies no authorization header."""
        return None
30
31
32# Credentials subclass for certificate authentication
class CertificateCredentials(Credentials):
    """Credentials backed by a client certificate.

    Builds an SSL context with ``init_context`` from ``cert_file`` and
    ``password``, optionally loads ``cert_chain`` into it, and hands the
    result to the base class.
    """

    def __init__(self, cert_file=None, password=None, cert_chain=None):
        context = init_context(cert=cert_file, cert_password=password)
        # A falsy cert_chain (None, '') means there is nothing extra to load.
        if cert_chain:
            context.load_cert_chain(cert_chain)
        super(CertificateCredentials, self).__init__(context)
39
40
41# Credentials subclass for JWT token based authentication
class TokenCredentials(Credentials):
    """Credentials that authenticate with a signed JWT bearer token.

    A token is created lazily per topic, cached together with its issue
    time, and replaced once it is older than the configured lifetime.
    No SSL context is used.
    """

    def __init__(self, auth_key_path, auth_key_id, team_id,
                 encryption_algorithm=None, token_lifetime=None):
        """Load the signing key and remember the token parameters.

        :param auth_key_path: path of the file holding the signing key; a
            falsy value results in an empty key.
        :param auth_key_id: value placed in the JWT ``kid`` header.
        :param team_id: value placed in the JWT ``iss`` claim.
        :param encryption_algorithm: JWT algorithm name; ``None`` selects
            ``DEFAULT_TOKEN_ENCRYPTION_ALGORITHM``.
        :param token_lifetime: seconds a cached token is reused before a
            new one is issued; ``None`` selects ``DEFAULT_TOKEN_LIFETIME``.
        """
        self.__auth_key = self._get_signing_key(auth_key_path)
        self.__auth_key_id = auth_key_id
        self.__team_id = team_id
        self.__encryption_algorithm = (
            DEFAULT_TOKEN_ENCRYPTION_ALGORITHM
            if encryption_algorithm is None else encryption_algorithm)
        self.__token_lifetime = (
            DEFAULT_TOKEN_LIFETIME
            if token_lifetime is None else token_lifetime)

        # Dictionary of {topic: (issue time, token string)}
        self.__topicTokens = {}

        # Use the default constructor because we don't have an SSL context
        super(TokenCredentials, self).__init__()

    def get_tokens(self):
        """Return the list of currently cached token strings."""
        # Iterate the values: iterating the dict itself yields the topic
        # keys, and indexing those returns characters of the topic name.
        return [pair[1] for pair in self.__topicTokens.values()]

    def get_authorization_header(self, topic):
        """Return the ``'bearer <token>'`` header value for ``topic``."""
        token = self._get_or_create_topic_token(topic)
        return 'bearer %s' % token

    def _isExpiredToken(self, issueDate):
        """Return True once ``issueDate`` is at least one lifetime old.

        Uses the lifetime configured on this instance, not the module
        default, so a custom ``token_lifetime`` is honored.
        """
        return time.time() >= issueDate + self.__token_lifetime

    def _get_or_create_topic_token(self, topic):
        """Return the cached token for ``topic``, issuing one if needed.

        A new token is signed when the topic has no cached token or the
        cached one has expired; the new token replaces the cache entry.
        """
        cached = self.__topicTokens.get(topic)
        if cached is not None and not self._isExpiredToken(cached[0]):
            return cached[1]

        # Create a new token
        issuedAt = time.time()
        claims = {'iss': self.__team_id,
                  'iat': issuedAt,
                  }
        headers = {'alg': self.__encryption_algorithm,
                   'kid': self.__auth_key_id,
                   }
        jwtToken = jwt.encode(claims, self.__auth_key,
                              algorithm=self.__encryption_algorithm,
                              headers=headers)
        # PyJWT 1.x returns bytes while 2.x returns str; only decode when
        # bytes came back so both versions work.
        if isinstance(jwtToken, bytes):
            jwtToken = jwtToken.decode('ascii')

        self.__topicTokens[topic] = (issuedAt, jwtToken)
        return jwtToken

    def _get_signing_key(self, key_path):
        """Return the contents of ``key_path``, or ``''`` if it is falsy."""
        secret = ''
        if key_path:
            with open(key_path) as f:
                secret = f.read()
        return secret
98