"""Protocol base class, protocol registry and media-type helpers for WSME."""
import weakref

import pkg_resources

from wsme.exc import ClientSideError


__all__ = [
    'CallContext',

    'register_protocol', 'getprotocol',
]

# Maps a protocol name to its protocol class. Filled explicitly by
# register_protocol() and lazily by getprotocol() after an entry-point lookup.
registered_protocols = {}


def _cfg(f):
    """Return the config dict stored on ``f._cfg``, creating it if missing.

    The dict is attached to ``f`` on first access, so every later call
    returns the same mutable object.
    """
    cfg = getattr(f, '_cfg', None)
    if cfg is None:
        f._cfg = cfg = {}
    return cfg


class expose(object):
    """Decorator that marks a protocol method as exposed on ``path``.

    Applying it sets ``func.exposed = True``. It stores ``content_type``
    under the ``'content-type'`` key of the function's config and appends
    ``path`` to the config's ``'paths'`` list.

    Stacking several ``expose`` decorators accumulates paths. Each one
    overwrites the content type, so the last one applied (the topmost
    decorator) wins.

    :param path: route path. It may contain ``$name`` placeholders, which
        ``Protocol.resolve_path`` expands from the protocol's attributes.
    :param content_type: content type recorded for the exposed function.
    """

    def __init__(self, path, content_type):
        self.path = path
        self.content_type = content_type

    def __call__(self, func):
        func.exposed = True
        cfg = _cfg(func)
        cfg['content-type'] = self.content_type
        cfg.setdefault('paths', []).append(self.path)
        return func


class CallContext(object):
    """Context object associated with a single request.

    Only a weak reference to the request is kept, so the context does not
    keep the request alive. ``path``, ``func`` and ``funcdef`` start as
    ``None``. Presumably the protocol/dispatcher fills them in while
    handling the call (verify against callers).
    """

    def __init__(self, request):
        self._request = weakref.ref(request)
        self.path = None

        self.func = None
        self.funcdef = None

    @property
    def request(self):
        """Return the request, or ``None`` if it has been garbage-collected."""
        return self._request()


class ObjectDict(object):
    """Read-only, mapping-like view of an object's attributes.

    ``Protocol.resolve_path`` passes it to ``string.Template.substitute``
    so that ``$name`` placeholders resolve to attributes of the protocol.
    """

    def __init__(self, obj):
        self.obj = obj

    def __getitem__(self, name):
        # NOTE: an unknown name raises AttributeError (from getattr) rather
        # than the KeyError a mapping would normally raise.
        return getattr(self.obj, name)


class Protocol(object):
    """Base class for protocols; the hook methods here are no-op defaults."""

    # Name under which the protocol is registered / looked up.
    name = None
    # Human-readable protocol name.
    displayname = None
    # Content types handled by the protocol. This list is shared by every
    # subclass that does not override it, so do not mutate it in place.
    content_types = []

    def resolve_path(self, path):
        """Expand ``$name`` placeholders in ``path`` from this object's
        attributes.

        Paths without a ``$`` are returned unchanged. An unknown placeholder
        raises AttributeError (see ObjectDict). A malformed placeholder
        raises ValueError from ``string.Template``.
        """
        if '$' in path:
            from string import Template
            s = Template(path)
            path = s.substitute(ObjectDict(self))
        return path

    def iter_routes(self):
        """Yield ``(resolved_path, bound_method)`` for each exposed route.

        Every name in ``dir(self)`` is fetched with ``getattr``, so any
        properties on the instance are evaluated along the way. An exposed
        method yields one pair per path registered by ``expose``.
        """
        for attrname in dir(self):
            attr = getattr(self, attrname)
            if getattr(attr, 'exposed', False):
                for path in _cfg(attr)['paths']:
                    yield self.resolve_path(path), attr

    def accept(self, request):
        """Return True if the request's Content-Type header is exactly one
        of ``content_types``.

        This is an exact membership test: header parameters are not
        stripped and case is not normalized.
        """
        return request.headers.get('Content-Type') in self.content_types

    def iter_calls(self, request):
        """Hook; no-op in the base class (returns None)."""
        pass

    def extract_path(self, context):
        """Hook; no-op in the base class (returns None)."""
        pass

    def read_arguments(self, context):
        """Hook; no-op in the base class (returns None)."""
        pass

    def encode_result(self, context, result):
        """Hook; no-op in the base class (returns None)."""
        pass

    # The encode_sample_* defaults return a placeholder 2-tuple. The tuple
    # presumably means (format/language label, rendered sample) - confirm
    # against callers.
    def encode_sample_value(self, datatype, value, format=False):
        """Return a placeholder sample for a single value."""
        return ('none', 'N/A')

    def encode_sample_params(self, params, format=False):
        """Return a placeholder sample for call parameters."""
        return ('none', 'N/A')

    def encode_sample_result(self, datatype, value, format=False):
        """Return a placeholder sample for a call result."""
        return ('none', 'N/A')


def register_protocol(protocol):
    """Register ``protocol`` under ``protocol.name``, replacing any previous
    entry with that name."""
    registered_protocols[protocol.name] = protocol


def getprotocol(name, **options):
    """Instantiate the protocol registered under ``name``.

    ``registered_protocols`` is consulted first. On a miss, the
    ``wsme.protocols`` entry-point group is searched and the loaded class is
    cached in ``registered_protocols``.

    :param name: protocol name.
    :param options: keyword arguments forwarded to the protocol constructor.
    :raises ValueError: if no protocol with that name can be found.
    """
    protocol_class = registered_protocols.get(name)
    if protocol_class is None:
        for entry_point in pkg_resources.iter_entry_points(
                'wsme.protocols', name):
            # Keep this check: when ``name`` is None, iter_entry_points
            # yields every entry point in the group. If several entry points
            # match, the last one loaded wins.
            if entry_point.name == name:
                protocol_class = entry_point.load()
        if protocol_class is None:
            raise ValueError("Cannot find protocol '%s'" % name)
        registered_protocols[name] = protocol_class
    return protocol_class(**options)


def media_type_accept(request, content_types):
    """Validate media types against request.method.

    When request.method is GET or HEAD, compare with the Accept header.

    When request.method is POST, PUT or PATCH, compare with the Content-Type
    header. That comparison is a case-insensitive prefix match, because
    media types are case-insensitive (RFC 7231, section 3.1.1.1).

    When request.method is DELETE, the media type is irrelevant, so return
    True.

    :param request: the incoming request. ``method``, ``accept`` and
        ``headers`` are read (presumably a WebOb request - confirm).
    :param content_types: media types the caller can handle.
    :returns: True if the request is acceptable. Returns False for GET/HEAD
        when ``request.accept`` is falsy (no Accept header) and for any
        other HTTP method.
    :raises ClientSideError: with status 406 if the Accept header matches
        none of ``content_types``; with status 415 if the Content-Type
        matches none of them; with ClientSideError's default status if a
        POST/PUT/PATCH request has no Content-Type header.
    """
    if request.method in ['GET', 'HEAD']:
        if request.accept:
            if request.accept.acceptable_offers(content_types):
                return True
            error_message = ('Unacceptable Accept type: %s not in %s'
                             % (request.accept, content_types))
            raise ClientSideError(error_message, status_code=406)
    elif request.method in ['PUT', 'POST', 'PATCH']:
        content_type = request.headers.get('Content-Type')
        if content_type:
            # Prefix match tolerates parameters such as "; charset=utf-8".
            # Both sides are lower-cased because media types are
            # case-insensitive.
            normalized = content_type.lower()
            if any(normalized.startswith(ct.lower()) for ct in content_types):
                return True
            error_message = ('Unacceptable Content-Type: %s not in %s'
                             % (content_type, content_types))
            raise ClientSideError(error_message, status_code=415)
        else:
            raise ClientSideError('missing Content-Type header')
    elif request.method in ['DELETE']:
        return True
    return False