1from flask import current_app, session
2from flask.signals import Namespace
3from ..base_client import FrameworkIntegration, OAuthError
4from ..requests_client import OAuth1Session, OAuth2Session
5
# Private signal namespace for this module.
_signal = Namespace()
#: Signal sent when a token is updated. ``FlaskIntegration.update_token``
#: sends it with ``current_app`` as the sender and the keyword arguments
#: ``name``, ``token``, ``refresh_token`` and ``access_token``.
token_update = _signal.signal('token_update')
9
10
class FlaskIntegration(FrameworkIntegration):
    """Flask-backed implementation of the framework integration hooks.

    Temporary data is kept in Flask's ``session``, token updates are
    announced through the ``token_update`` signal, and client settings are
    read from the Flask application config.
    """

    oauth1_client_cls = OAuth1Session
    oauth2_client_cls = OAuth2Session

    def _session_key(self, key):
        """Return the session key under which ``key`` is stored for this client."""
        return '_{}_authlib_{}_'.format(self.name, key)

    def set_session_data(self, request, key, value):
        """Store ``value`` in the Flask session under a per-client key.

        :param request: unused here; Flask's ``session`` proxy is used instead.
        :param key: short name of the value to store.
        :param value: value to store in the session.
        """
        session[self._session_key(key)] = value

    def get_session_data(self, request, key):
        """Remove and return the value previously stored for ``key``.

        The value is popped from the session, so a second call with the same
        ``key`` returns ``None`` unless the value has been set again.

        :param request: unused here; Flask's ``session`` proxy is used instead.
        :param key: short name of the value to fetch.
        :return: the stored value, or ``None`` when nothing is stored.
        """
        return session.pop(self._session_key(key), None)

    def update_token(self, token, refresh_token=None, access_token=None):
        """Send the ``token_update`` signal with ``current_app`` as sender.

        Receivers get ``name`` (this integration's ``name``), ``token``,
        ``refresh_token`` and ``access_token`` as keyword arguments.
        """
        token_update.send(
            current_app,
            name=self.name,
            token=token,
            refresh_token=refresh_token,
            access_token=access_token,
        )

    def generate_access_token_params(self, request_token_url, request):
        """Collect the callback parameters needed to fetch an access token.

        :param request_token_url: when truthy, all query-string arguments are
            returned as a flat dict (presumably the OAuth 1 flow -- confirm
            against the caller).
        :param request: the incoming request; ``method``, ``args`` and
            ``form`` are read from it.
        :return: a dict of parameters; on the non-``request_token_url`` path
            it contains ``code`` and ``state``.
        :raises OAuthError: when the callback carries an ``error`` parameter.
            A missing ``code`` propagates the lookup error raised by the
            request's parameter mapping.
        """
        if request_token_url:
            return request.args.to_dict(flat=True)

        # GET callbacks carry their parameters in the query string, every
        # other method in the form body.
        values = request.args if request.method == 'GET' else request.form

        # Check for a provider error on both transports: a form-posted
        # error response has no ``code``, so without this check it would
        # surface as a missing-key failure instead of an OAuthError.
        error = values.get('error')
        if error:
            description = values.get('error_description')
            raise OAuthError(error=error, description=description)

        return {
            'code': values['code'],
            'state': values.get('state'),
        }

    @staticmethod
    def load_config(oauth, name, params):
        """Read the settings for client ``name`` from the Flask app config.

        Each entry of ``params`` is looked up in ``oauth.app.config`` under
        the upper-cased key ``'<name>_<param>'``.

        :param oauth: object exposing the Flask application as ``oauth.app``.
        :param name: client name used as the config key prefix.
        :param params: iterable of parameter names to look up.
        :return: dict mapping each parameter name to its configured value;
            parameters whose config value is missing or ``None`` are omitted.
        """
        found = {}
        for param in params:
            conf_key = '{}_{}'.format(name, param).upper()
            value = oauth.app.config.get(conf_key, None)
            if value is not None:
                found[param] = value
        return found
62