1# Copyright 2016 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""wsgi implement behaviour that provides service control as wsgi
16middleware.
17
It provides the :class:`Middleware`, which is a WSGI middleware implementation
that wraps another WSGI application and uses a provided
:class:`google.api.control.client.Client` to provide service control.
21
22"""
23#pylint: disable=too-many-arguments
24
25from __future__ import absolute_import
26
27from datetime import datetime
28import httplib
29import logging
30import os
31import socket
32import uuid
33import urllib2
34import urlparse
35import wsgiref.util
36
37from google.api.auth import suppliers, tokens
38from . import check_request, messages, report_request, service
39
40
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Lower-cased name of the response header that carries the body size; it is
# compared against lower-cased response header names.
_CONTENT_LENGTH = 'content-length'
# Location value recorded in every report request.
_DEFAULT_LOCATION = 'global'

# URL probed to detect whether the process runs on Google Compute Engine.
_METADATA_SERVER_URL = 'http://metadata.google.internal'
48
49
def _running_on_gce(timeout=3.0):
    """Return True if the GCE metadata server answers this process.

    The metadata server URL is requested with the ``Metadata-Flavor: Google``
    header; the process is considered to be on GCE only when the response
    carries the same header value.

    Args:
      timeout (float): seconds to wait on the connection before giving up.
        This function is called while the module is being imported, so
        without a bound an unresponsive address would block the import.

    Returns:
      bool: True when the metadata server responded as expected, False when
        it did not or when the request failed.
    """
    headers = {'Metadata-Flavor': 'Google'}
    response = None

    try:
        request = urllib2.Request(_METADATA_SERVER_URL, headers=headers)
        response = urllib2.urlopen(request, timeout=timeout)
        return response.info().getheader('Metadata-Flavor') == 'Google'
    except (urllib2.URLError, socket.error):
        # Unreachable or unresolvable metadata server: not on GCE.
        return False
    finally:
        # Always release the underlying connection.
        if response is not None:
            response.close()
62
63
def _get_platform():
    """Work out which platform this process is running on.

    The checks run in this order, the first match winning:

    - ``SERVER_SOFTWARE`` starting with ``Development`` -> DEVELOPMENT
    - ``KUBERNETES_SERVICE_HOST`` set -> GKE
    - metadata server reachable -> GAE_FLEX if ``GAE_MODULE_NAME`` is set,
      otherwise GCE
    - ``GAE_MODULE_NAME`` set -> GAE_STANDARD

    Returns:
      one of the ``report_request.ReportedPlatforms`` values; UNKNOWN when no
      check matches.
    """
    if os.environ.get('SERVER_SOFTWARE', '').startswith('Development'):
        return report_request.ReportedPlatforms.DEVELOPMENT

    if os.environ.get('KUBERNETES_SERVICE_HOST'):
        return report_request.ReportedPlatforms.GKE

    # The metadata probe comes after the cheap environment checks.
    on_gce = _running_on_gce()
    has_gae_module = bool(os.environ.get('GAE_MODULE_NAME'))

    if on_gce:
        # We're either in GAE Flex or GCE
        if has_gae_module:
            return report_request.ReportedPlatforms.GAE_FLEX
        return report_request.ReportedPlatforms.GCE

    if has_gae_module:
        return report_request.ReportedPlatforms.GAE_STANDARD

    return report_request.ReportedPlatforms.UNKNOWN
81
82
# The platform is detected once, when this module is imported.
platform = _get_platform()


def running_on_devserver():
    """Return True if the detected platform is the development server."""
    is_devserver = platform == report_request.ReportedPlatforms.DEVELOPMENT
    return is_devserver
88
89
def add_all(application, project_id, control_client,
            loader=service.Loaders.FROM_SERVICE_MANAGEMENT):
    """Wraps a wsgi application with all the default endpoints middleware.

    The application is wrapped, innermost first, by :class:`Middleware`,
    then by :class:`AuthenticationMiddleware` (only when the loaded service
    yields an authenticator), then by :class:`EnvironmentMiddleware`.

    Example:

      >>> application = MyWsgiApp()  # an existing WSGI application
      >>>
      >>> # the name of the controlled service
      >>> service_name = 'my-service-name'
      >>>
      >>> # A GCP project  with service control enabled
      >>> project_id = 'my-project-id'
      >>>
      >>> # wrap the app for service control
      >>> from google.api.control import client, wsgi
      >>> control_client = client.Loaders.DEFAULT.load(service_name)
      >>> control_client.start()
      >>> wrapped_app = wsgi.add_all(application, project_id, control_client)
      >>>
      >>> # now use wrapped_app in place of app

    Args:
       application: the wrapped wsgi application
       project_id: the project_id that's providing service control support
       control_client: the service control client instance
       loader (:class:`google.api.control.service.Loader`): loads the service
          instance that configures this instance's behaviour

    Returns:
       :class:`EnvironmentMiddleware`: the outermost middleware, to be used
       in place of ``application``

    Raises:
       ValueError: if the loader does not return a service
    """
    loaded_service = loader.load()
    if not loaded_service:
        raise ValueError("Failed to load service config")

    authenticator = _create_authenticator(loaded_service)

    controlled_app = Middleware(application, project_id, control_client)
    inner_app = (AuthenticationMiddleware(controlled_app, authenticator)
                 if authenticator else controlled_app)
    return EnvironmentMiddleware(inner_app, loaded_service)
130
131
def _next_operation_uuid():
    """Return a fresh random operation id as a 32-character hex string."""
    operation_uuid = uuid.uuid4()
    return operation_uuid.hex
134
135
class EnvironmentMiddleware(object):
    """A WSGI middleware that publishes service config in the environment.

    On every request it stores the following keys in the wsgi environment:

    - google.api.config.service
    - google.api.config.service_name
    - google.api.config.method_registry
    - google.api.config.reporting_rules

    and, only when the request matches a method in the registry:

    - google.api.config.method_info
    """
    # pylint: disable=too-few-public-methods

    SERVICE = 'google.api.config.service'
    SERVICE_NAME = 'google.api.config.service_name'
    METHOD_REGISTRY = 'google.api.config.method_registry'
    METHOD_INFO = 'google.api.config.method_info'
    REPORTING_RULES = 'google.api.config.reporting_rules'

    def __init__(self, application, a_service):
        """Initializes a new Middleware instance.

        Args:
          application: the wrapped wsgi application
          a_service (:class:`google.api.gen.servicecontrol_v1_messages.Service`):
            a service instance

        Raises:
          ValueError: if ``a_service`` is not a ``messages.Service``
        """
        if not isinstance(a_service, messages.Service):
            raise ValueError("service is None or not an instance of Service")

        self._application = application
        self._service = a_service

        # Both values are derived once from the service and reused per request.
        self._method_registry, self._reporting_rules = self._configure()

    def _configure(self):
        """Builds the method registry and reporting rules for the service.

        Returns:
          tuple: ``(method_registry, reporting_rules)``
        """
        report_spec = service.extract_report_spec(self._service)
        logs, metric_names, label_names = report_spec
        rules = report_request.ReportingRules.from_known_inputs(
            logs=logs,
            metric_names=metric_names,
            label_names=label_names)
        return service.MethodRegistry(self._service), rules

    def __call__(self, environ, start_response):
        """Adds the config entries to ``environ`` and calls the wrapped app."""
        config_entries = (
            (self.SERVICE, self._service),
            (self.SERVICE_NAME, self._service.name),
            (self.METHOD_REGISTRY, self._method_registry),
            (self.REPORTING_RULES, self._reporting_rules),
        )
        for key, value in config_entries:
            environ[key] = value

        # Method info is keyed on the HTTP verb and the path of the request.
        request_path = urlparse.urlparse(
            wsgiref.util.request_uri(environ)).path
        matched = self._method_registry.lookup(
            environ.get('REQUEST_METHOD'), request_path)
        if matched:
            environ[self.METHOD_INFO] = matched

        return self._application(environ, start_response)
195
196
class Middleware(object):
    """A WSGI middleware implementation that provides service control.

    Requests whose wsgi environment carries no method info (see
    :class:`EnvironmentMiddleware`) are passed straight to the wrapped
    application.  For all other requests a check is made before the wrapped
    application is called, and a report request is sent afterwards, or
    immediately if the request is rejected.

    Example:

      >>> app = MyWsgiApp()  # an existing WSGI application
      >>>
      >>> # the name of the controlled service
      >>> service_name = 'my-service-name'
      >>>
      >>> # A GCP project  with service control enabled
      >>> project_id = 'my-project-id'
      >>>
      >>> # wrap the app for service control
      >>> from google.api.control import client, wsgi, service
      >>> control_client = client.Loaders.DEFAULT.load(service_name)
      >>> control_client.start()
      >>> a_service = service.Loaders.FROM_SERVICE_MANAGEMENT.load()
      >>> wrapped_app = wsgi.Middleware(app, project_id, control_client)
      >>> env_app = wsgi.EnvironmentMiddleware(wrapped_app, a_service)
      >>>
      >>> # now use env_app in place of app

    """
    # pylint: disable=too-few-public-methods, fixme
    _NO_API_KEY_MSG = (
        'Method does not allow callers without established identity.'
        ' Please use an API key or other form of API consumer identity'
        ' to call this API.'
    )

    def __init__(self,
                 application,
                 project_id,
                 control_client,
                 next_operation_id=_next_operation_uuid,
                 timer=datetime.utcnow):
        """Initializes a new Middleware instance.

        Args:
           application: the wrapped wsgi application
           project_id: the project_id that's providing service control support
           control_client: the service control client instance
           next_operation_id (func): produces the next operation
           timer (func[[datetime.datetime]]): a func that obtains the current time
           """
        self._application = application
        self._project_id = project_id
        self._next_operation_id = next_operation_id
        self._control_client = control_client
        self._timer = timer

    def __call__(self, environ, start_response):
        """Handles one request, checking and reporting it when applicable.

        Args:
           environ: the wsgi environment of the request
           start_response: the wsgi ``start_response`` callable

        Returns:
           the response of the wrapped application when there is no method
           info; the error message when the request is rejected; otherwise
           the joined response body of the wrapped application
        """
        # pylint: disable=too-many-locals
        method_info = environ.get(EnvironmentMiddleware.METHOD_INFO)
        if not method_info:
            # just allow the wrapped application to handle the request
            logger.debug('method_info not present in the wsgi environment'
                         ', no service control')
            return self._application(environ, start_response)

        latency_timer = _LatencyTimer(self._timer)
        latency_timer.start()

        # Determine if the request can proceed
        http_method = environ.get('REQUEST_METHOD')
        parsed_uri = urlparse.urlparse(wsgiref.util.request_uri(environ))
        app_info = _AppInfo()
        # TODO: determine if any of the more complex ways of getting the request size
        # (e.g) buffering and counting the wsgi input stream is more appropriate here
        content_length = environ.get('CONTENT_LENGTH')
        # A missing or empty CONTENT_LENGTH means "no size given", not a
        # malformed one; the default request size is kept without a warning.
        if content_length not in (None, ''):
            try:
                app_info.request_size = int(content_length)
            except ValueError:
                logger.warn('ignored bad content-length: %s', content_length)

        app_info.http_method = http_method
        app_info.url = parsed_uri

        check_info = self._create_check_info(method_info, parsed_uri, environ)
        if not check_info.api_key and not method_info.allow_unregistered_calls:
            logger.debug("skipping %s, no api key was provided", parsed_uri)
            error_msg = self._handle_missing_api_key(app_info, start_response)
        else:
            check_req = check_info.as_check_request()
            # log the request that is sent, not the check_request module
            logger.debug('checking %s with %s', method_info, check_req)
            check_resp = self._control_client.check(check_req)
            error_msg = self._handle_check_response(app_info, check_resp, start_response)

        if error_msg:
            # send a report request that indicates that the request failed
            rules = environ.get(EnvironmentMiddleware.REPORTING_RULES)
            latency_timer.end()
            report_req = self._create_report_request(method_info,
                                                     check_info,
                                                     app_info,
                                                     latency_timer,
                                                     rules)
            logger.debug('scheduling report_request %s', report_req)
            self._control_client.report(report_req)
            # NOTE(review): the message is returned as a bare string rather
            # than an iterable of strings; kept as-is for compatibility.
            return error_msg

        # update the client with the response
        latency_timer.app_start()

        # run the application request in an inner handler that sets the status
        # and response code on app_info
        def inner_start_response(status, response_headers, exc_info=None):
            app_info.response_code = int(status.partition(' ')[0])
            for name, value in response_headers:
                if name.lower() == _CONTENT_LENGTH:
                    app_info.response_size = int(value)
                    break
            return start_response(status, response_headers, exc_info)

        result = self._application(environ, inner_start_response)

        # perform reporting, result must be joined otherwise the latency record
        # is incorrect
        try:
            body = b''.join(result)
        finally:
            # The caller only ever sees the joined body, so the application's
            # iterable has to be closed here or it never will be.
            close = getattr(result, 'close', None)
            if close is not None:
                close()
        latency_timer.end()
        app_info.response_size = len(body)
        rules = environ.get(EnvironmentMiddleware.REPORTING_RULES)
        report_req = self._create_report_request(method_info,
                                                 check_info,
                                                 app_info,
                                                 latency_timer,
                                                 rules)
        logger.debug('scheduling report_request %s', report_req)
        self._control_client.report(report_req)
        return body

    def _create_report_request(self,
                               method_info,
                               check_info,
                               app_info,
                               latency_timer,
                               reporting_rules):
        """Builds the report request for a handled request.

        Args:
           method_info: the method info of the called method
           check_info: the check info created for the request
           app_info (:class:`_AppInfo`): what is known about the request and
             its response
           latency_timer (:class:`_LatencyTimer`): the timings of the request
           reporting_rules: the rules used to convert the info into a request

        Returns:
           the result of ``report_request.Info.as_report_request``
        """
        # TODO: determine how to obtain the consumer_project_id and the location
        # correctly
        report_info = report_request.Info(
            api_key=check_info.api_key,
            api_key_valid=app_info.api_key_valid,
            api_method=method_info.selector,
            consumer_project_id=self._project_id,  # TODO: see above
            location=_DEFAULT_LOCATION,  # TODO: see above
            method=app_info.http_method,
            operation_id=check_info.operation_id,
            operation_name=check_info.operation_name,
            backend_time=latency_timer.backend_time,
            overhead_time=latency_timer.overhead_time,
            platform=platform,
            producer_project_id=self._project_id,
            protocol=report_request.ReportedProtocols.HTTP,
            request_size=app_info.request_size,
            request_time=latency_timer.request_time,
            response_code=app_info.response_code,
            response_size=app_info.response_size,
            referer=check_info.referer,
            service_name=check_info.service_name,
            url=app_info.url
        )
        return report_info.as_report_request(reporting_rules, timer=self._timer)

    def _create_check_info(self, method_info, parsed_uri, environ):
        """Builds the check info for a request.

        The api key is looked for, in order, in the query parameters
        configured for the method, the headers configured for the method, and
        finally the default query parameters.

        Args:
           method_info: the method info of the called method
           parsed_uri: the parsed request uri
           environ: the wsgi environment of the request

        Returns:
           :class:`check_request.Info`: the info used to check the request
        """
        service_name = environ.get(EnvironmentMiddleware.SERVICE_NAME)
        operation_id = self._next_operation_id()
        api_key = _find_api_key_param(method_info, parsed_uri)
        if not api_key:
            api_key = _find_api_key_header(method_info, environ)
        if not api_key:
            api_key = _find_default_api_key_param(parsed_uri)

        # a key counts as valid as soon as one is present
        api_key_valid = bool(api_key)

        check_info = check_request.Info(
            api_key=api_key,
            api_key_valid=api_key_valid,
            client_ip=environ.get('REMOTE_ADDR', ''),
            consumer_project_id=self._project_id,  # TODO: switch this to producer_project_id
            operation_id=operation_id,
            operation_name=method_info.selector,
            referer=environ.get('HTTP_REFERER', ''),
            service_name=service_name
        )
        return check_info

    def _handle_check_response(self, app_info, check_resp, start_response):
        """Converts a check response into an error response if it failed.

        On failure the error status is sent via ``start_response`` and the
        response code and api key validity are recorded on ``app_info``.

        Returns:
           None if the check succeeded, otherwise the error message
        """
        code, detail, api_key_valid = check_request.convert_response(
            check_resp, self._project_id)
        if code == httplib.OK:
            return None  # the check was OK

        # there was problem; the request cannot proceed
        logger.warn('Check failed %d, %s', code, detail)
        error_msg = '%d %s' % (code, detail)
        start_response(error_msg, [])
        app_info.response_code = code
        app_info.api_key_valid = api_key_valid
        return error_msg  # the request cannot continue

    def _handle_missing_api_key(self, app_info, start_response):
        """Rejects a request that needed an api key but did not have one.

        The unauthorized status is sent via ``start_response`` and recorded on
        ``app_info``, whose api key is marked as not valid.

        Returns:
           the error message
        """
        code = httplib.UNAUTHORIZED
        detail = self._NO_API_KEY_MSG
        logger.warn('Check not performed %d, %s', code, detail)
        error_msg = '%d %s' % (code, detail)
        start_response(error_msg, [])
        app_info.response_code = code
        app_info.api_key_valid = False
        return error_msg  # the request cannot continue
408
409
class _AppInfo(object):
    """Mutable record of what is known about a request and its response.

    A new instance starts with no http method or url, unset sizes, a valid
    api key and an internal server error response code.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self):
        # details of the request
        self.http_method = None
        self.url = None
        self.request_size = report_request.SIZE_NOT_SET
        self.api_key_valid = True
        # details of the response
        self.response_code = httplib.INTERNAL_SERVER_ERROR
        self.response_size = report_request.SIZE_NOT_SET
420
421
class _LatencyTimer(object):
    """Records when a request starts, reaches the application and ends.

    The three instants are obtained from the ``timer`` callable; the
    properties expose the differences between them, or None while one of the
    instants they need has not been recorded.
    """

    def __init__(self, timer):
        """Initializes a new instance.

        Args:
           timer (func): a callable that returns the current time
        """
        self._timer = timer
        self._start = None
        self._app_start = None
        self._end = None

    def start(self):
        """Records the start of the request."""
        self._start = self._timer()

    def app_start(self):
        """Records the moment the wrapped application is called."""
        self._app_start = self._timer()

    def end(self):
        """Records the end of the request.

        If the application was never started, its start is set to the end, so
        that all the elapsed time counts as overhead.
        """
        self._end = self._timer()
        if self._app_start is None:
            self._app_start = self._end

    # The properties compare against None rather than relying on truthiness,
    # so that a timer returning a falsy value (e.g. 0) is still honoured.

    @property
    def request_time(self):
        """The time from start to end, or None if either is unset."""
        if self._start is not None and self._end is not None:
            return self._end - self._start
        return None

    @property
    def overhead_time(self):
        """The time from start to application start, or None if unset."""
        if self._start is not None and self._app_start is not None:
            return self._app_start - self._start
        return None

    @property
    def backend_time(self):
        """The time from application start to end, or None if unset."""
        if self._end is not None and self._app_start is not None:
            return self._end - self._app_start
        return None
458
459
def _find_api_key_param(info, parsed_uri):
    """Finds the api key in the query parameters configured for a method.

    Args:
       info: the method info; its ``api_key_url_query_params`` names the
         query parameters to look at, in order of preference
       parsed_uri: the parsed request uri

    Returns:
       the first value of the first configured parameter that is present in
       the query, or None if there is none
    """
    candidate_names = info.api_key_url_query_params
    if not candidate_names:
        return None

    query_values = urlparse.parse_qs(parsed_uri.query)
    if query_values:
        for name in candidate_names:
            matches = query_values.get(name)
            if matches:
                # each parameter maps to a list of values; take the first
                return matches[0]

    return None
477
478
# Query parameters searched for an api key when the method-specific lookups
# find nothing, in order of preference.
_DEFAULT_API_KEYS = ('key', 'api_key')


def _find_default_api_key_param(parsed_uri):
    """Finds the api key in the default query parameters.

    Args:
       parsed_uri: the parsed request uri

    Returns:
       the first value of the first default parameter that is present in the
       query, or None if there is none
    """
    query_values = urlparse.parse_qs(parsed_uri.query)
    candidates = (query_values.get(name) for name in _DEFAULT_API_KEYS)
    # each parameter maps to a list of values; take the first of the first hit
    return next((values[0] for values in candidates if values), None)
495
496
def _find_api_key_header(info, environ):
    """Finds the api key in the request headers configured for a method.

    Args:
       info: the method info; its ``api_key_http_header`` names the headers
         to look at, in order of preference
       environ: the wsgi environment of the request

    Returns:
       the value of the first configured header that is present and
       non-empty, or None if there is none
    """
    headers = info.api_key_http_header
    if not headers:
        return None

    for header in headers:
        upper_name = header.upper()
        # The wsgi environment exposes a header as HTTP_<NAME> with dashes
        # turned into underscores, e.g. X-API-Key -> HTTP_X_API_KEY.  The key
        # with dashes left in is still tried afterwards for compatibility.
        for key in ('HTTP_' + upper_name.replace('-', '_'),
                    'HTTP_' + upper_name):
            value = environ.get(key)
            if value:
                return value  # headers have single values

    return None
508
def _create_authenticator(a_service):
    """Create an instance of :class:`google.api.auth.tokens.Authenticator`.

    Args:
      a_service (:class:`google.api.gen.servicecontrol_v1_messages.Service`): a
        service instance

    Returns:
      the authenticator built from the service's authentication providers, or
      None when the service has no authentication configured

    Raises:
      ValueError: if ``a_service`` is not a ``messages.Service``
    """
    if not isinstance(a_service, messages.Service):
        raise ValueError("service is None or not an instance of Service")

    auth_config = a_service.authentication
    if not auth_config:
        logger.info("authentication is not configured in service, "
                    "authentication checks will be disabled")
        return None

    provider_ids_by_issuer = {}
    uri_configs_by_issuer = {}
    for auth_provider in auth_config.providers:
        issuer = auth_provider.issuer
        jwks_uri = auth_provider.jwksUri

        # Enable openID discovery if jwks_uri is unset
        use_open_id = jwks_uri is None
        uri_configs_by_issuer[issuer] = suppliers.IssuerUriConfig(
            use_open_id, jwks_uri)
        provider_ids_by_issuer[issuer] = auth_provider.id

    jwks_supplier = suppliers.JwksSupplier(
        suppliers.KeyUriSupplier(uri_configs_by_issuer))
    return tokens.Authenticator(provider_ids_by_issuer, jwks_supplier)
540
541
class AuthenticationMiddleware(object):
    """A WSGI middleware that does authentication checks for incoming
    requests.

    The authentication result is stored in the wsgi environment under
    ``USER_INFO``; it is None when the request could not be authenticated.

    In environments where os.environ is replaced with a request-local and
    thread-independent copy (e.g. Google Appengine), authentication result is
    added to os.environ so that the wrapped application can make use of the
    authentication result.
    """
    # pylint: disable=too-few-public-methods

    USER_INFO = "google.api.auth.user_info"

    def __init__(self, application, authenticator):
        """Initializes an authentication middleware instance.

        Args:
          application: a WSGI application to be wrapped
          authenticator (:class:`google.api.auth.tokens.Authenticator`): an
            authenticator that authenticates incoming requests

        Raises:
          ValueError: if ``authenticator`` is not a ``tokens.Authenticator``
        """
        if not isinstance(authenticator, tokens.Authenticator):
            raise ValueError("Invalid authenticator")

        self._application = application
        self._authenticator = authenticator

    def __call__(self, environ, start_response):
        """Authenticates the request, then calls the wrapped application.

        Requests for methods without authentication configuration are passed
        through untouched.  A missing or unverifiable token does not reject
        the request; the user info is simply left as None.
        """
        method_info = environ.get(EnvironmentMiddleware.METHOD_INFO)
        if not method_info or not method_info.auth_info:
            # No authentication configuration for this method
            logger.debug("authentication is not configured")
            return self._application(environ, start_response)

        auth_token = _extract_auth_token(environ)
        user_info = None
        if not auth_token:
            logger.debug("No auth token is attached to the request")
        else:
            try:
                service_name = environ.get(EnvironmentMiddleware.SERVICE_NAME)
                user_info = self._authenticator.authenticate(auth_token,
                                                             method_info.auth_info,
                                                             service_name)
            except Exception:  # pylint: disable=broad-except
                logger.debug("Cannot decode and verify the auth token. The backend "
                             "will not be able to retrieve user info", exc_info=True)

        environ[self.USER_INFO] = user_info

        # pylint: disable=protected-access
        if user_info and not isinstance(os.environ, os._Environ):
            # Set user info into os.environ only if os.environ is replaced
            # with a request-local copy
            os.environ[self.USER_INFO] = user_info

        try:
            return self._application(environ, start_response)
        finally:
            # Erase user info from os.environ for safety and sanity; this is
            # done even when the wrapped application raises, so that the user
            # info is never left behind.
            if self.USER_INFO in os.environ:
                del os.environ[self.USER_INFO]
605
606
# Name of the query parameter that may carry the auth token.
_ACCESS_TOKEN_PARAM_NAME = "access_token"
# Prefix of an authorization header value that carries a bearer token.
_BEARER_TOKEN_PREFIX = "Bearer "
_BEARER_TOKEN_PREFIX_LEN = len(_BEARER_TOKEN_PREFIX)


def _extract_auth_token(environ):
    """Extracts the auth token of a request from its wsgi environment.

    The authorization header takes precedence: when it is present the query
    string is not consulted, even if the header holds no bearer token.

    Args:
      environ: the wsgi environment of the request

    Returns:
      the token from a ``Bearer`` authorization header, else the first
      ``access_token`` query parameter value, else None
    """
    # First try to extract auth token from HTTP authorization header.
    auth_header = environ.get("HTTP_AUTHORIZATION")
    if auth_header:
        if auth_header.startswith(_BEARER_TOKEN_PREFIX):
            return auth_header[_BEARER_TOKEN_PREFIX_LEN:]
        return None

    # Then try to read auth token from query.
    parameters = urlparse.parse_qs(environ.get("QUERY_STRING", ""))
    token_values = parameters.get(_ACCESS_TOKEN_PARAM_NAME)
    if token_values:
        # The parameter may be repeated in the query; take the first value
        # rather than failing on a caller-supplied duplicate.
        return token_values[0]
    return None
625