1import weakref
2
3import pkg_resources
4
5from wsme.exc import ClientSideError
6
7
__all__ = [
    'CallContext',
    'register_protocol',
    'getprotocol',
]

# Protocol name -> protocol class.  Filled by register_protocol() and,
# lazily, by getprotocol() when a protocol is found via entry points.
registered_protocols = {}
15
16
17def _cfg(f):
18    cfg = getattr(f, '_cfg', None)
19    if cfg is None:
20        f._cfg = cfg = {}
21    return cfg
22
23
class expose(object):
    """Decorator marking a function as exposed under *path*.

    The decorated function gets ``exposed = True``.  Its ``_cfg`` dict
    records ``'content-type'`` (overwritten by each application) and
    accumulates every decorated path in ``'paths'``, so the decorator can be
    stacked to expose one function under several paths.
    """

    def __init__(self, path, content_type):
        self.path, self.content_type = path, content_type

    def __call__(self, func):
        func.exposed = True
        config = _cfg(func)
        config['content-type'] = self.content_type
        known_paths = config.setdefault('paths', [])
        known_paths.append(self.path)
        return func
35
36
class CallContext(object):
    """Context object carrying a request plus the path/func/funcdef of a call.

    ``path``, ``func`` and ``funcdef`` start out as None and are presumably
    filled in by the dispatching code.  The request is held through a weak
    reference, so the context does not keep the request alive.
    """

    def __init__(self, request):
        self._request = weakref.ref(request)
        self.path = self.func = self.funcdef = None

    @property
    def request(self):
        """The request object, or None once it has been garbage-collected."""
        deref = self._request
        return deref()
48
49
class ObjectDict(object):
    """Read-only mapping view whose keys are attribute names of ``obj``.

    ``view[name]`` is ``getattr(obj, name)``.  A missing attribute raises
    AttributeError, not KeyError.
    """

    def __init__(self, obj):
        self.obj = obj

    def __getitem__(self, name):
        target = self.obj
        value = getattr(target, name)
        return value
56
57
class Protocol(object):
    """Base class for protocol implementations.

    Subclasses set ``name``, ``displayname`` and ``content_types`` and
    override the hook methods.  In this base class ``iter_calls``,
    ``extract_path``, ``read_arguments`` and ``encode_result`` do nothing
    and return None, and the ``encode_sample_*`` methods return the
    placeholder ``('none', 'N/A')``.
    """

    # Key under which register_protocol() stores the class.
    name = None
    # Human-readable name of the protocol.
    displayname = None
    # Media types accepted by the default ``accept`` implementation.
    content_types = []

    def resolve_path(self, path):
        """Expand ``$placeholder``s in *path* from attributes of ``self``.

        Uses ``string.Template`` syntax.  Each placeholder is resolved with
        ``getattr(self, name)``, so an unknown placeholder raises
        AttributeError.  Paths without ``$`` are returned unchanged.
        """
        if '$' in path:
            from string import Template
            s = Template(path)
            path = s.substitute(ObjectDict(self))
        return path

    def iter_routes(self):
        """Yield ``(resolved_path, bound_method)`` for every exposed method.

        A method is exposed when it carries a truthy ``exposed`` attribute
        (set by the ``expose`` decorator).  One pair is yielded per path
        recorded in its ``_cfg['paths']``.
        """
        for attrname in dir(self):
            attr = getattr(self, attrname)
            if getattr(attr, 'exposed', False):
                for path in _cfg(attr)['paths']:
                    yield self.resolve_path(path), attr

    def accept(self, request):
        """Return True if the request's Content-Type is in ``content_types``.

        Only the media type is compared.  Parameters such as
        ``; charset=utf-8`` are ignored, and the comparison is
        case-insensitive because media types are case-insensitive
        (RFC 7231, section 3.1.1.1).  A missing or empty header never
        matches.
        """
        content_type = request.headers.get('Content-Type')
        if not content_type:
            return False

        def media_type(value):
            # Strip parameters and surrounding whitespace, normalise case.
            return value.split(';', 1)[0].strip().lower()

        wanted = media_type(content_type)
        return any(media_type(ct) == wanted for ct in self.content_types)

    def iter_calls(self, request):
        """Hook for subclasses; the base implementation returns None."""
        pass

    def extract_path(self, context):
        """Hook for subclasses; the base implementation returns None."""
        pass

    def read_arguments(self, context):
        """Hook for subclasses; the base implementation returns None."""
        pass

    def encode_result(self, context, result):
        """Hook for subclasses; the base implementation returns None."""
        pass

    def encode_sample_value(self, datatype, value, format=False):
        """Return a sample encoding of *value*; the base returns a placeholder.

        Presumably a ``(language/format, text)`` pair for documentation.
        The base implementation always returns ``('none', 'N/A')``.
        """
        return ('none', 'N/A')

    def encode_sample_params(self, params, format=False):
        """Return a sample encoding of *params*; the base returns a placeholder.

        The base implementation always returns ``('none', 'N/A')``.
        """
        return ('none', 'N/A')

    def encode_sample_result(self, datatype, value, format=False):
        """Return a sample encoding of a result; the base returns a placeholder.

        The base implementation always returns ``('none', 'N/A')``.
        """
        return ('none', 'N/A')
100
101
def register_protocol(protocol):
    """Store *protocol* in ``registered_protocols`` keyed by ``protocol.name``.

    An existing entry with the same name is replaced.
    """
    registered_protocols.update({protocol.name: protocol})
104
105
def getprotocol(name, **options):
    """Instantiate the protocol called *name*, passing it ``**options``.

    The class is looked up in ``registered_protocols`` first.  Failing
    that, the ``wsme.protocols`` entry points named *name* are loaded (if
    several match, the last one loaded wins), and the result is cached in
    ``registered_protocols`` under *name*.

    :raises ValueError: if no protocol class can be found.
    """
    cls = registered_protocols.get(name)
    if cls is not None:
        return cls(**options)

    # Not registered yet: fall back to setuptools entry points.
    for ep in pkg_resources.iter_entry_points('wsme.protocols', name):
        if ep.name != name:
            continue
        cls = ep.load()

    if cls is None:
        raise ValueError("Cannot find protocol '%s'" % name)
    registered_protocols[name] = cls
    return cls(**options)
117
118
def media_type_accept(request, content_types):
    """Validate the request's media type against *content_types*.

    Which header is checked depends on ``request.method``:

    * GET / HEAD: the Accept header.  If ``request.accept`` is truthy, at
      least one of *content_types* must be acceptable
      (``request.accept.acceptable_offers``), otherwise ClientSideError
      with status 406 is raised.  If it is falsy (presumably no Accept
      header), False is returned.
    * POST / PUT / PATCH: the Content-Type header, which must start with one
      of *content_types*, so parameters like ``; charset=utf-8`` are
      tolerated.  A mismatch raises ClientSideError with status 415.  A
      missing or empty header raises ClientSideError with its default
      status code.
    * DELETE: the media type is irrelevant, so True is returned.
    * Any other method: False.

    :param request: request object exposing ``method``, ``accept`` and
        ``headers`` (presumably a WebOb request; verify against callers).
    :param content_types: media types supported by the caller.
    :return: True if acceptable, False if undecided (see above).
    :raises ClientSideError: on an unacceptable or missing media type.
    """
    method = request.method
    if method in ('GET', 'HEAD'):
        if request.accept:
            if request.accept.acceptable_offers(content_types):
                return True
            error_message = ('Unacceptable Accept type: %s not in %s'
                             % (request.accept, content_types))
            raise ClientSideError(error_message, status_code=406)
    elif method in ('PUT', 'POST', 'PATCH'):
        content_type = request.headers.get('Content-Type')
        if not content_type:
            raise ClientSideError('missing Content-Type header')
        # Prefix match on the already-fetched header value, so media-type
        # parameters after the type do not cause a rejection.
        if any(content_type.startswith(ct) for ct in content_types):
            return True
        error_message = ('Unacceptable Content-Type: %s not in %s'
                         % (content_type, content_types))
        raise ClientSideError(error_message, status_code=415)
    elif method == 'DELETE':
        return True
    return False
148