1# -*- coding: utf-8 -*-
2# Copyright (c) 2015 Spotify AB
3
4from __future__ import absolute_import, division, print_function
5
6
7import json
8import logging
9import os
10import re
11import sys
12
13try:
14    from collections import OrderedDict
15except ImportError:  # NOCOV
16    from ordereddict import OrderedDict
17
18from six import iterkeys, iteritems
19import xmltodict
20
21from .parameters import (
22    Body, URIParameter, Header, FormParameter, QueryParameter
23)
24
PYVER = sys.version_info[:3]

# PEP 476: the stdlib HTTP clients verify TLS certificates by default
# starting with 2.7.9 on the 2.x line and 3.4.3 on the 3.x line.
_STDLIB_VERIFIES_CERTS = (
    (2, 7, 9) <= PYVER < (3, 0, 0) or PYVER >= (3, 4, 3)
)

# Pick the download backend once, at import time:
#   URLLIB          -> True when ``urllib`` (via six) is the backend,
#                      False when ``requests`` is.
#   SECURE_DOWNLOAD -> False only when the chosen backend cannot verify
#                      the host's certificate.
# NOTE(review): backend selection still keys on the two exact releases so
# the names bound at import time (``requests`` vs. ``urllib``) stay the
# same for existing installs; confirm whether ``>=`` was intended here.
if PYVER == (2, 7, 9) or PYVER == (3, 4, 3):  # NOCOV
    import six.moves.urllib.request as urllib
    import six.moves.urllib.error as urllib_error
    URLLIB = True
    SECURE_DOWNLOAD = True
else:
    try:  # NOCOV
        import requests
        URLLIB = False
        SECURE_DOWNLOAD = True
    except ImportError:
        import six.moves.urllib.request as urllib
        import six.moves.urllib.error as urllib_error
        URLLIB = True
        # Falling back to urllib is only insecure on interpreters that
        # predate PEP 476; newer ones verify certificates on their own.
        SECURE_DOWNLOAD = _STDLIB_VERIFIES_CERTS
42
43from .errors import MediaTypeError
44
45
# Location of the IANA media-types registry (XML); this is the URL that
# ``update_mime_types`` downloads and parses.
IANA_URL = "https://www.iana.org/assignments/media-types/media-types.xml"
47
48
def load_schema(data):
    """
    Best-effort decoding of schema/example data.

    JSON is attempted first, then XML.  Any failure in a parser is
    swallowed on purpose; if neither parser succeeds the input is handed
    back untouched.

    :param str data: schema/example data
    :ret: parsed JSON value, parsed XML mapping, or ``data`` itself
    """
    try:
        parsed = json.loads(data)
    except Exception:  # deliberate catch-all: fall through to XML
        try:
            parsed = xmltodict.parse(data)
        except Exception:  # deliberate catch-all: give the raw data back
            parsed = data
    return parsed
68
69
def setup_logger(key):
    """
    Return a DEBUG-level console logger whose lines are prefixed by ``key``.

    Each ``key`` gets its own child logger of this module's logger, and the
    console handler is attached only once per key.  Repeated calls therefore
    return the same logger without stacking handlers, so a message is not
    emitted several times or under another caller's prefix.

    :param str key: prefix put in front of every formatted log line
    :ret: configured ``logging.Logger``
    """
    log = logging.getLogger("{0}.{1}".format(__name__, key))
    log.setLevel(logging.DEBUG)

    msg = "{key} - %(levelname)s - %(message)s".format(key=key)
    # Handlers installed by this function are tagged with their format so
    # a later call can recognise them and skip adding a duplicate.
    installed = any(
        getattr(handler, "_setup_logger_fmt", None) == msg
        for handler in log.handlers
    )
    if not installed:
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(logging.Formatter(msg))
        console._setup_logger_fmt = msg
        log.addHandler(console)
    return log
82
83
def _requests_download(url):
    """
    Download a URL using ``requests`` library.

    HTTP error statuses (4xx/5xx) are treated as failures, matching the
    ``urllib`` based downloader, rather than returning the error page body.

    :param str url: address to fetch
    :ret: decoded response body
    :raises MediaTypeError: on any ``requests`` failure, including an \
        unsuccessful HTTP status
    """
    try:
        response = requests.get(url)
        # HTTPError is a RequestException, so it is wrapped below as well.
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        msg = "Error downloading from {0}: {1}".format(url, e)
        raise MediaTypeError(msg)
    return response.text
92
93
def _urllib_download(url):
    """
    Download a URL using ``urllib`` library.

    :param str url: address to fetch
    :ret: raw response body as returned by ``response.read()``
    :raises MediaTypeError: if opening the URL raises ``URLError``
    """
    try:
        response = urllib.urlopen(url)
    except urllib_error.URLError as e:
        msg = "Error downloading from {0}: {1}".format(url, e)
        raise MediaTypeError(msg)
    # try/finally instead of ``with``: the Python 2 response object is not
    # a context manager, but it must still be closed after reading.
    try:
        return response.read()
    finally:
        response.close()
102
103
def download_url(url):
    """
    General download function, given a URL.

    The backend was chosen at import time and is read from the module
    flags: ``requests`` when ``URLLIB`` is false, ``urllib`` otherwise.
    When ``SECURE_DOWNLOAD`` is false a warning is logged before the
    ``urllib`` download, because the certificate can not be verified.

    :param str url: address to fetch
    :ret: downloaded content
    :raises MediaTypeError: propagated from the backend on failure
    """
    log = setup_logger("DOWNLOAD")
    if SECURE_DOWNLOAD:
        if URLLIB:
            return _urllib_download(url)
        return _requests_download(url)

    msg = ("Downloading over HTTPS but can not verify the host's "
           "certificate.  To avoid this in the future, `pip install"
           " \"requests[security]\"`.")
    # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported name.
    log.warning(msg)
    return _urllib_download(url)
121
122
def _xml_to_dict(response_text):
    """
    Turn an XML document into a Python mapping via ``xmltodict``.

    :param response_text: XML payload
    :ret: parsed mapping
    :raises MediaTypeError: if the XML parser reports an ``ExpatError``
    """
    try:
        parsed = xmltodict.parse(response_text)
    except xmltodict.expat.ExpatError as err:
        raise MediaTypeError("Error parsing XML: {0}".format(err))
    return parsed
130
131
def _extract_mime_types(registry):
    """
    Parse out MIME types from a defined registry (e.g. "application",
    "audio", etc).

    For each record the text of its ``file`` element is used when present;
    otherwise the type is assembled as ``<registry @id>/<record name>``.
    Records with neither are skipped.

    :param dict registry: one parsed ``registry`` element
    :ret: list of MIME type strings, in record order
    """
    mime_types = []
    records = registry.get("record") or []
    # A registry holding a single record is parsed into a bare mapping
    # rather than a one-element list; normalise so the loop sees records.
    if isinstance(records, dict):
        records = [records]
    reg_name = registry.get("@id")

    for rec in records:
        if not isinstance(rec, dict):
            # empty/text-only record elements carry nothing to extract
            continue
        file_node = rec.get("file")
        if isinstance(file_node, dict):
            # element with attributes: the text lives under "#text"
            mime = file_node.get("#text")
        else:
            # attribute-less element parses to a plain string (or None)
            mime = file_node
        if mime:
            mime_types.append(mime)
            continue
        name = rec.get("name")
        if name:
            mime_types.append(reg_name + "/" + name)
    return mime_types
150
151
def _parse_xml_data(xml_data):
    """
    Parse the given XML data.

    :param dict xml_data: parsed IANA document; the sub-registries are \
        read from ``xml_data["registry"]["registry"]``
    :ret: flat list of MIME types collected from every sub-registry
    :raises MediaTypeError: if no registries are present, or if their \
        count is not the expected 9
    """
    registries = xml_data.get("registry", {}).get("registry")
    if not registries:
        msg = "No registries found to parse."
        raise MediaTypeError(msg)
    # Compare by value: ``is not`` on an int literal tests object identity
    # and only works by accident of small-integer caching.
    if len(registries) != 9:
        msg = ("Expected 9 registries but parsed "
               "{0}".format(len(registries)))
        raise MediaTypeError(msg)

    all_mime_types = []
    for registry in registries:
        all_mime_types.extend(_extract_mime_types(registry))
    return all_mime_types
168
169
def _save_updated_mime_types(output_file, mime_types):
    """
    Write ``mime_types`` to ``output_file`` as JSON, replacing its content.

    :param str output_file: path of the file to (over)write
    :param mime_types: JSON-serialisable collection of MIME types
    """
    with open(output_file, "w") as handle:
        json.dump(obj=mime_types, fp=handle)
174
175
def update_mime_types():
    """
    Refresh the packaged list of IANA MIME media types.

    Downloads the registry from ``IANA_URL``, extracts the MIME types and
    writes them to ``data/supported_mime_types.json`` next to this module.
    Requires internet connection.
    """
    log = setup_logger("UPDATE")

    log.debug("Getting XML data from IANA")
    raw_data = download_url(IANA_URL)

    log.debug("Data received; parsing...")
    mime_types = _parse_xml_data(_xml_to_dict(raw_data))

    package_dir = os.path.dirname(os.path.realpath(__file__))
    output_file = os.path.join(
        package_dir, "data", "supported_mime_types.json"
    )
    _save_updated_mime_types(output_file, mime_types)

    log.debug("Done! Supported IANA MIME media types have been updated.")
195
196
def _resource_type_lookup(assigned, root):
    """
    Find the resource type named ``assigned`` on the RAML root.

    :param str assigned: The string name of the assigned resource type
    :param root: RAML root object
    :ret: first entry of ``root.resource_types`` whose ``name`` equals \
        ``assigned``, or ``None`` if there is no such entry
    """
    available = root.resource_types
    if not available:
        return None
    matches = [rt for rt in available if rt.name == assigned]
    return matches[0] if matches else None
209
210
211#####
212# Helper methods
213#####
214
215# general
def _get(data, item, default=None):
    """
    ``dict.get`` that tolerates ``data`` not being a mapping.

    RAML allows empty mappings, which show up as ``None``.  If ``data``
    has no usable ``get`` (an ``AttributeError`` is raised), or the key is
    absent, ``default`` is returned.

    :param data: RAML data
    :param str item: RAML key
    :param default: default value if item is not in dict
    :ret: value for RAML key
    """
    try:
        value = data.get(item, default)
    except AttributeError:
        value = default
    return value
231
232
def _create_base_param_obj(attribute_data, param_obj, config, errors, **kw):
    """
    Build one ``param_obj`` instance per entry of ``attribute_data``.

    :param attribute_data: mapping of parameter name to its raw RAML data
    :param param_obj: parameter class to instantiate
    :param config: passed through to every created object
    :param errors: passed through to every created object
    :param kw: ``method`` is read from here when ``param_obj`` is ``Header``
    :ret: list of created objects, or ``None`` if nothing was created
    """
    # URI parameters are required unless stated otherwise; all other
    # parameter kinds default to optional.
    required_by_default = param_obj is URIParameter
    needs_method = param_obj is Header

    created = []
    for name, spec in list(iteritems(attribute_data)):
        kwargs = {
            "name": name,
            "raw": {name: spec},
            "desc": _get(spec, "description"),
            "display_name": _get(spec, "displayName", name),
            "min_length": _get(spec, "minLength"),
            "max_length": _get(spec, "maxLength"),
            "minimum": _get(spec, "minimum"),
            "maximum": _get(spec, "maximum"),
            "default": _get(spec, "default"),
            "enum": _get(spec, "enum"),
            "example": _get(spec, "example"),
            "required": _get(spec, "required", default=required_by_default),
            "repeat": _get(spec, "repeat", False),
            "pattern": _get(spec, "pattern"),
            "type": _get(spec, "type", "string"),
            "config": config,
            "errors": errors,
        }
        if needs_method:
            kwargs["method"] = _get(kw, "method")
        created.append(param_obj(**kwargs))

    return created or None
268
269
270# used for traits & resource nodes
def _map_param_unparsed_str_obj(param):
    """
    Map a raw RAML parameter section name to its parameter class.

    :param str param: RAML key such as ``"queryParameters"``
    :ret: the matching parameter class
    :raises KeyError: if ``param`` is not one of the known section names
    """
    classes_by_section = {
        "queryParameters": QueryParameter,
        "uriParameters": URIParameter,
        "formParameters": FormParameter,
        "baseUriParameters": URIParameter,
        "headers": Header,
    }
    return classes_by_section[param]
279
280
281# create_resource_types
def _get_union(resource, method, inherited):
    """
    Merge ``inherited`` with the data stored under ``resource[method]``.

    Keys present on only one side are copied over.  For keys present on
    both sides the two mappings are merged, with the inherited entries
    winning on conflicts.  If ``resource`` has nothing under ``method``
    the result is empty.

    :param dict resource: raw resource data
    :param method: key under which the method data is stored
    :param dict inherited: inherited data
    :ret: merged ``dict``
    """
    inherited_items = list(iteritems(inherited))
    method_data = resource.get(method)

    merged = {}
    if method_data is None:
        return merged

    for name, inherited_value in inherited_items:
        if name in list(iterkeys(method_data)):
            own_value = method_data.get(name)
            merged[name] = dict(list(iteritems(own_value)) +
                                list(iteritems(inherited_value)))
        else:
            merged[name] = inherited_value

    for name, own_value in list(iteritems(method_data)):
        if name not in list(iterkeys(inherited)):
            merged[name] = own_value
    return merged
298
299
def __is_scalar(item):
    """Return ``True`` if ``item`` is one of the scalar RAML property names."""
    return item in (
        "type", "enum", "pattern", "minLength", "maxLength",
        "minimum", "maximum", "example", "repeat", "required",
        "default", "description", "usage", "schema", "displayName",
    )
308
309
def __get_sets(child, parent):
    """
    Classify the keys of ``child`` and ``parent``.

    :param child: child mapping (may be empty or ``None``)
    :param parent: parent mapping (may be empty or ``None``)
    :ret: ``(child, parent, child_only, parent_only, shared)`` where the \
        last three are lists of keys.  ``shared`` additionally holds every \
        child key whose optional form (key + ``"?"``) is a parent key.
    """
    child_keys = list(iterkeys(child)) if child else []
    parent_keys = list(iterkeys(parent)) if parent else []

    child_only = list(set(child_keys) - set(parent_keys))
    parent_only = list(set(parent_keys) - set(child_keys))

    shared = list(set(child_keys).intersection(parent_keys))
    for key in child_keys:
        if str(key) + "?" in parent_keys:
            shared.append(key)

    return child, parent, child_only, parent_only, shared
324
325
def _get_data_union(child, parent):
    """
    Recursively merge ``parent`` data into ``child`` data.

    Child-only keys are kept.  Parent-only keys are copied, except
    optional method keys (a method name ending in ``"?"``), which are
    dropped.  For shared keys, scalar properties take the child's value
    and everything else is merged recursively.

    :param child: child data
    :param parent: parent data
    :ret: merged ``dict``
    """
    # FIXME: should bring this over from config, not hard code
    methods = [
        'get', 'post', 'put', 'delete', 'patch', 'head', 'options',
        'trace', 'connect', 'get?', 'post?', 'put?', 'delete?', 'patch?',
        'head?', 'options?', 'trace?', 'connect?'
    ]
    child, parent, c_diff, p_diff, inters = __get_sets(child, parent)

    union = {}
    for key in c_diff:
        union[key] = child.get(key)

    for key in p_diff:
        # an optional method only applies if the child declares it
        if key in methods and key.endswith("?"):
            continue
        union[key] = parent.get(key)

    for key in inters:
        if __is_scalar(key):
            union[key] = child.get(key)
            continue
        union[key] = _get_data_union(child.get(key, {}), parent.get(key, {}))
    return union
354
355
def _get_inherited_resource(res_name, resource_types):
    """
    Return the entry of ``resource_types`` whose first key is ``res_name``.

    :param res_name: name to look for
    :param resource_types: iterable of single-name mappings
    :ret: the matching mapping, or ``None`` if none matches
    """
    for candidate in resource_types:
        first_key = list(iterkeys(candidate))[0]
        if res_name == first_key:
            return candidate
    return None
360
361
def _get_res_type_attribute(res_data, method_data, item, default=None):
    """
    Look up ``item`` at both the method level and the resource level.

    :param res_data: resource-level data
    :param method_data: method-level data
    :param str item: key to look up
    :param default: value used when ``item`` is missing; a fresh empty \
        ``dict`` is created per call when omitted
    :ret: tuple ``(method_level, resource_level)``
    """
    # A ``{}`` literal in the signature would be one shared object handed
    # back to every caller; build the fallback per call instead.
    if default is None:
        default = {}
    method_level = _get(method_data, item, default)
    resource_level = _get(res_data, item, default)
    return method_level, resource_level
366
367
def _get_inherited_type_params(data, attribute, params, resource_types):
    """
    Merge ``params`` with the ``attribute`` entries of the assigned type.

    The resource type named by ``data["type"]`` is looked up in
    ``resource_types``; its entries win over ``params`` on conflicts.

    :param dict data: raw resource data carrying a ``type`` key
    :param str attribute: attribute to read from the resource type
    :param dict params: parameters defined on the resource itself
    :param resource_types: raw resource type definitions
    :ret: merged ``dict``
    """
    type_name = data.get("type")
    wrapper = _get_inherited_resource(type_name, resource_types)
    definition = wrapper.get(data.get("type"))
    inherited_params = definition.get(attribute, {})

    pairs = list(iteritems(params))
    pairs.extend(list(iteritems(inherited_params)))
    return dict(pairs)
375
376
def _get_inherited_item(items, item_name, res_types, meth_, _data):
    """
    Merge ``items`` with ``item_name`` data from the assigned resource type.

    :param dict items: already collected items
    :param str item_name: key to read from the type's method data
    :param res_types: raw resource type definitions
    :param meth_: method key inside the resource type
    :param dict _data: raw resource data carrying a ``type`` key
    :ret: merged ``dict``; entries from the resource type win on conflicts
    """
    type_name = _data.get("type")
    wrapper = _get_inherited_resource(type_name, res_types)
    definition = wrapper.get(_data.get("type"))

    # NOTE(review): both lookups below read the same location
    # (``definition[meth_][item_name]``); the first was presumably meant to
    # read the resource level instead -- confirm before changing.
    res_level = definition.get(meth_, {}).get(item_name, {})
    method_level = definition.get(meth_, {}).get(item_name, {})

    pairs = list(iteritems(items))
    pairs += list(iteritems(res_level))
    pairs += list(iteritems(method_level))
    return dict(pairs)
390
391
def _get_attribute_dict(data, item, v):
    """
    Merge the ``item`` mappings found in ``v`` and in ``data``.

    :param data: method-level data; its entries win on conflicts
    :param str item: key to look up on both levels
    :param v: resource-level data
    :ret: merged ``dict``
    """
    resource_level = _get(v, item, {})
    method_level = _get(data, item, {})

    pairs = list(iteritems(resource_level))
    pairs.extend(list(iteritems(method_level)))
    return dict(pairs)
397
398
def _map_attr(attribute):
    """
    Translate a raw RAML key into the matching parsed attribute name.

    :param str attribute: RAML key, e.g. ``"queryParameters"``
    :ret: attribute name, e.g. ``"query_params"``
    :raises KeyError: if ``attribute`` is not a known RAML key
    """
    parsed_names = {
        "mediaType": "media_type",
        "protocols": "protocols",
        "headers": "headers",
        "body": "body",
        "responses": "responses",
        "uriParameters": "uri_params",
        "baseUriParameters": "base_uri_params",
        "queryParameters": "query_params",
        "formParameters": "form_params",
        "description": "description",
    }
    return parsed_names[attribute]
412
413
414# create_node
def _get_method(attribute, method, raw_data):
    """
    Returns ``attribute`` defined at the method level.

    An empty ``dict`` is returned when ``raw_data`` has no usable entry
    for ``method`` or that entry has no ``attribute``.
    """
    method_data = _get(raw_data, method, {})
    return _get(method_data, attribute, {})
426
427
def _get_resource(attribute, raw_data):
    """
    Returns ``attribute`` defined at the resource level, or an empty
    ``dict`` if ``raw_data`` does not contain it.
    """
    fallback = {}
    return raw_data.get(attribute, fallback)
431
432
def _get_parent(attribute, parent):
    """
    Return ``attribute`` read from ``parent``.

    An empty ``dict`` is returned when ``parent`` is falsy or does not
    have the attribute.
    """
    return getattr(parent, attribute, {}) if parent else {}
437
438
439# needs/uses parsed raml data
def _get_resource_type(attribute, root, type_, method):
    """
    Returns ``attribute`` defined in the resource type.

    Only the first entry of ``root.resource_types`` matching both
    ``type_`` (by ``name``) and ``method`` is considered.  An empty list
    is returned when there is no such entry, or when it lacks the
    attribute or holds ``None`` for it.
    """
    if not (type_ and root.resource_types):
        return []

    matches = [
        r for r in root.resource_types
        if r.name == type_ and r.method == method
    ]
    if not matches:
        return []

    value = getattr(matches[0], attribute, None)
    return [] if value is None else value
451
452
def _get_trait(attribute, root, is_):
    """
    Returns ``attribute`` collected from the traits named in ``is_``.

    For every name in ``is_`` the first trait of ``root.traits`` with that
    name contributes its ``attribute`` values, if it has any.  An empty
    list is returned when ``is_`` or ``root.traits`` is empty.
    """
    if not is_:
        return []
    traits = root.traits
    if not traits:
        return []

    collected = []
    for trait_name in is_:
        matches = [t for t in traits if t.name == trait_name]
        if not matches:
            continue
        value = getattr(matches[0], attribute, None)
        if value is not None:
            collected.extend(value)
    return collected
468
469
def _get_scheme(item, root):
    """
    Find the raw security scheme referenced by ``item``.

    :param item: scheme name as ``str``, or a mapping whose first key is \
        the scheme name
    :param root: object whose ``raw`` mapping may hold ``securitySchemes``
    :ret: the first scheme mapping whose first key equals the name, or \
        ``None`` if nothing matches or ``item`` is of another type
    """
    for scheme in root.raw.get("securitySchemes", []):
        if isinstance(item, str):
            wanted = item
        elif isinstance(item, dict):
            wanted = list(iterkeys(item))[0]
        else:
            continue
        if wanted == list(iterkeys(scheme))[0]:
            return scheme
    return None
479
480
def _get_attribute(attribute, method, raw_data):
    """
    Combine the method-level and resource-level values of ``attribute``.

    Method-level entries come first in the resulting order; for a key
    defined on both levels the resource-level value is the one kept.

    :ret: ``OrderedDict`` of the merged entries
    """
    method_level = _get_method(attribute, method, raw_data)
    resource_level = _get_resource(attribute, raw_data)

    pairs = list(iteritems(method_level))
    pairs.extend(list(iteritems(resource_level)))
    return OrderedDict(pairs)
486
487
def _get_inherited_attribute(attribute, root, type_, method, is_):
    """
    Collect ``attribute`` from the assigned resource type and traits.

    :ret: resource type values followed by trait values
    """
    from_type = _get_resource_type(attribute, root, type_, method)
    from_traits = _get_trait(attribute, root, is_)
    return from_type + from_traits
492
493
494# TODO: refactor - this ain't pretty
def _remove_duplicates(inherit_params, resource_params):
    """
    Combine inherited and resource-level parameters without duplicates.

    Inherited parameters that are also defined at the resource level are
    dropped; the resource-level ones are then appended.  ``Body`` objects
    are identified by ``mime_type``, everything else by ``name``.  Which
    of the two the resource level is indexed by is decided by its first
    element.

    :param inherit_params: inherited parameter objects (may be empty/None)
    :param resource_params: resource-level objects (may be empty/None)
    :ret: combined list, or ``None`` if it would be empty
    """
    # An empty or missing resource-level list previously blew up on
    # ``resource_params[0]``; treat both inputs as empty lists instead.
    resource_params = resource_params or []
    inherit_params = inherit_params or []

    if resource_params and isinstance(resource_params[0], Body):
        taken = [p.mime_type for p in resource_params]
    else:
        taken = [p.name for p in resource_params]

    ret = []
    for p in inherit_params:
        key = p.mime_type if isinstance(p, Body) else p.name
        if key not in taken:
            ret.append(p)
    ret.extend(resource_params)
    return ret or None
511
512
def _map_inheritance(nodetype):
    """
    Return the lookup function for an inheritance source.

    :param str nodetype: one of ``"traits"``, ``"types"``, ``"method"``, \
        ``"resource"``, ``"parent"`` or ``"root"``
    :raises KeyError: for any other value
    """
    lookups = {
        "traits": __trait,
        "types": __resource_type,
        "method": __method,
        "resource": __resource,
        "parent": __parent,
        "root": __root,
    }
    return lookups[nodetype]
522
523
def __trait(item, **kwargs):
    """Look up ``item`` on the traits named by the ``is_`` keyword."""
    return _get_trait(item, kwargs.get("root"), kwargs.get("is_"))
528
529
def __resource_type(item, **kwargs):
    """
    Look up ``item`` on the resource type given by the ``type_`` and
    ``method`` keywords, after translating it to its parsed name.
    """
    parsed_name = _map_attr(item)
    return _get_resource_type(
        parsed_name,
        kwargs.get("root"),
        kwargs.get("type_"),
        kwargs.get("method"),
    )
536
537
def __method(item, **kwargs):
    """Look up ``item`` at the method level of the ``data`` keyword."""
    return _get_method(item, kwargs.get("method"), kwargs.get("data"))
542
543
def __resource(item, **kwargs):
    """Look up ``item`` at the resource level of the ``data`` keyword."""
    return _get_resource(item, kwargs.get("data"))
547
548
def __parent(item, **kwargs):
    """Look up ``item`` on the object given by the ``parent`` keyword."""
    return _get_parent(item, kwargs.get("parent"))
552
553
def __root(item, **kwargs):
    """
    Read ``item`` from the ``root`` keyword under its parsed attribute
    name; ``None`` if the root does not have that attribute.
    """
    root = kwargs.get("root")
    return getattr(root, _map_attr(item), None)
558
559
def get_inherited(item, inherit_from=[], **kwargs):
    """
    Gather ``item`` from every inheritance source in ``inherit_from``.

    :param str item: raw RAML key to look up
    :param inherit_from: source names understood by ``_map_inheritance``
    :param kwargs: context forwarded unchanged to each lookup function
    :ret: ``dict`` mapping each source name to what its lookup returned
    """
    collected = {}
    for source in inherit_from:
        lookup = _map_inheritance(source)
        collected[source] = lookup(item, **kwargs)
    return collected
567
568
569#####
570# set uri, form, query, header objects for traits
571#####
572
def set_param_object(param_data, param_str, root):
    """
    Create parameter objects for the ``param_str`` section of a trait.

    :param param_data: raw data that may contain ``param_str``
    :param str param_str: raw RAML section name, e.g. ``"headers"``
    :param root: object providing ``config`` and ``errors``
    :ret: list of parameter objects, or ``None`` if there are none
    """
    raw_params = _get(param_data, param_str, {})
    param_cls = _map_param_unparsed_str_obj(param_str)
    return _create_base_param_obj(
        raw_params, param_cls, root.config, root.errors
    )
580
581
582#####
583# set query, form, uri params for resource nodes
584#####
585
586# <--[uri]-->
def __create_params(unparsed, parsed, method, raw_data, root, cls, type_, is_):
    """
    Build declared parameter objects and fetch the inherited ones.

    :ret: tuple ``(params, param_objs)``: objects created from the raw \
        ``unparsed`` data (or ``None``), and the objects inherited under \
        the ``parsed`` name from the resource type and traits
    """
    declared = _get_attribute(unparsed, method, raw_data)
    inherited_objs = _get_inherited_attribute(parsed, root, type_,
                                              method, is_)
    created = _create_base_param_obj(declared, cls, root.config, root.errors)
    return created, inherited_objs
593
594
def _create_uri_params(unparsed, parsed, root_params, root, type_, is_,
                       method, raw_data, parent):
    """
    Assemble the URI parameters of a resource node.

    Order of the result: inherited objects, objects declared on the
    resource, the parent's ``uri_params``, then those ``root_params``
    whose name is not already present.

    :ret: list of parameter objects, or ``None`` if it would be empty
    """
    declared, collected = __create_params(unparsed, parsed, method,
                                          raw_data, root, URIParameter,
                                          type_, is_)

    collected.extend(declared or [])
    if parent and parent.uri_params:
        collected.extend(parent.uri_params)
    if root_params:
        seen = [p.name for p in collected]
        collected.extend([p for p in root_params if p.name not in seen])
    return collected or None
609# <--[uri]-->
610
611
612# <--[query, base uri, form]-->
def _check_already_exists(param, ret_list):
    """
    Append ``param`` to ``ret_list`` unless an equal key is already there.

    ``Body`` objects are keyed by ``mime_type``; all other objects by
    ``name``.  ``ret_list`` is modified in place and also returned.
    """
    key_attr = "mime_type" if isinstance(param, Body) else "name"
    existing = [getattr(p, key_attr) for p in ret_list]
    if getattr(param, key_attr) not in existing:
        ret_list.append(param)
    return ret_list
626
627
628# TODO: refactor - this ain't pretty
def __remove_duplicates(to_clean):
    """
    Flatten several parameter collections, keeping first occurrences.

    :param to_clean: collections in priority order \
        (resource, inherited, parent, root); empty ones are skipped
    :ret: de-duplicated list, or ``None`` if it would be empty
    """
    unique = []
    for group in to_clean:
        for param in group or []:
            unique = _check_already_exists(param, unique)
    return unique or None
638
639
def _map_parsed_str(parsed):
    """
    Convert a parsed attribute name into its raw RAML section name.

    The last ``_``-separated word is replaced by ``parameters`` and the
    result is camel-cased, e.g. ``"query_params"`` -> ``"queryParameters"``.
    """
    words = parsed.split("_")[:-1] + ["parameters"]
    camel = "".join(word.capitalize() for word in words)
    return camel[0].lower() + camel[1:]
646
647
def set_params(data, param_str, root, method, inherit=False, **kw):
    """
    Assemble query, base URI or form parameters for a resource node.

    Sources, in priority order: parameters declared in ``data``, inherited
    ones (only if ``inherit`` is true), the parent's, then root-level ones
    not already inherited by name.  Duplicates are dropped, first wins.

    :param data: raw resource data
    :param str param_str: parsed attribute name, e.g. ``"query_params"``
    :param root: object providing ``config`` and ``errors``
    :param method: method key of the resource
    :param bool inherit: also collect from the resource type and traits
    :param kw: optional ``type_``, ``is_``, ``parent`` and ``root_params``
    :ret: list of parameter objects, or ``None`` if it would be empty
    """
    unparsed = _map_parsed_str(param_str)
    param_cls = _map_param_unparsed_str_obj(unparsed)
    declared = _get_attribute(unparsed, method, data)
    own = _create_base_param_obj(declared, param_cls,
                                 root.config, root.errors) or []

    # inherited objects
    inherited = []
    if inherit:
        inherited = _get_inherited_attribute(param_str, root,
                                             kw.get("type_"), method,
                                             kw.get("is_"))

    # parent objects
    parent = kw.get("parent")
    from_parent = getattr(parent, param_str, []) if parent else []

    # root objects
    from_root = []
    root_level = kw.get("root_params")
    if root_level:
        inherited_names = [p.name for p in inherited]
        from_root = [p for p in root_level if p.name not in inherited_names]

    return __remove_duplicates((own, inherited, from_parent, from_root))
681# <--[query, base uri, form]-->
682
683
684# preserve order of URI and Base URI parameters
685# used for RootNode, ResourceNode
def _preserve_uri_order(path, param_objs, config, errors, declared=None):
    """
    Order URI parameter objects by their position in ``path``.

    Placeholders (``{name}``) found in ``path`` that have no matching
    object get a default string ``URIParameter`` created for them, unless
    they are listed in ``declared`` or are named ``version``.  This only
    happens when the path has more placeholders than there are objects.
    Created objects are appended to ``param_objs`` when a non-empty list
    was passed in.

    :param path: URI template; if a list is given its first item is used
    :param param_objs: existing parameter objects (may be empty/None)
    :param config: passed to every created ``URIParameter``
    :param errors: passed to every created ``URIParameter``
    :param declared: names of (base) URI parameters declared elsewhere
    :ret: objects in path order, or ``None`` if none match a placeholder
    """
    # if this is hit, RAML shouldn't be valid anyways.
    if isinstance(path, list):
        path = path[0]
    if declared is None:
        declared = []

    sorted_params = []
    # raw string: ``\{`` and ``\}`` are not valid str escape sequences
    pattern = r"\{(.*?)\}"
    params = re.findall(pattern, path)
    if not param_objs:
        param_objs = []
    # if there are URI parameters in the path but were not declared
    # inline, we should create them.
    # TODO: Probably shouldn't do it in this function, though...
    if len(params) > len(param_objs):
        if len(param_objs) > 0:
            param_names = [p.name for p in param_objs]
            missing = [p for p in params if p not in param_names]
        else:
            missing = params[::]
        # exclude any (base)uri params if already declared
        missing = [p for p in missing if p not in declared]
        for m in missing:
            # no need to create a URI param for version
            if m == "version":
                continue
            data = {"type": "string"}
            _param = URIParameter(name=m,
                                  raw={m: data},
                                  required=True,
                                  display_name=m,
                                  desc=_get(data, "description"),
                                  min_length=_get(data, "minLength"),
                                  max_length=_get(data, "maxLength"),
                                  minimum=_get(data, "minimum"),
                                  maximum=_get(data, "maximum"),
                                  default=_get(data, "default"),
                                  enum=_get(data, "enum"),
                                  example=_get(data, "example"),
                                  repeat=_get(data, "repeat", False),
                                  pattern=_get(data, "pattern"),
                                  type=_get(data, "type", "string"),
                                  config=config,
                                  errors=errors)
            param_objs.append(_param)
    for p in params:
        _param = [i for i in param_objs if i.name == p]
        if _param:
            sorted_params.append(_param[0])
    return sorted_params or None
735