1# encoding=utf8
2
3import unittest
4
5from wsme import WSRoot
6from wsme.protocol import getprotocol, CallContext, Protocol
7import wsme.protocol
8
9
class DummyProtocol(Protocol):
    """Minimal protocol stub used by the protocol registration tests.

    It accepts every request, maps every call to the ``['touch']`` path,
    and records how many times arguments were read and which request
    was seen most recently.
    """
    name = 'dummy'
    # NOTE(review): presumably lets this protocol match requests with an
    # empty or missing content type -- confirm against wsme.protocol.
    content_types = ['', None]

    def __init__(self):
        # Number of times read_arguments() has been called.
        self.hits = 0
        # Request seen by the most recent read_arguments() call; None until
        # the first call, so the attribute always exists on an instance.
        self.lastreq = None

    def accept(self, req):
        """Accept every request unconditionally."""
        return True

    def iter_calls(self, req):
        """Yield a single CallContext wrapping *req*."""
        yield CallContext(req)

    def extract_path(self, context):
        """Always return the path ``['touch']``, whatever the context."""
        return ['touch']

    def read_arguments(self, context):
        """Record the call and return no arguments.

        Stores ``context.request`` in ``lastreq`` and increments ``hits``.
        Returns an empty dict.
        """
        self.lastreq = context.request
        self.hits += 1
        return {}

    def encode_result(self, context, result):
        """Return ``str(result)``."""
        return str(result)

    def encode_error(self, context, infos):
        """Return ``str(infos)``."""
        return str(infos)
36
37
def test_getprotocol():
    """getprotocol() must raise ValueError for an unknown protocol name."""
    try:
        getprotocol('invalid')
    except ValueError:
        pass
    else:
        # Raise explicitly instead of ``assert False``: assert statements
        # are stripped under ``python -O``, which would make this test
        # pass even when no ValueError is raised.
        raise AssertionError("ValueError was not raised")
44
45
class TestProtocols(unittest.TestCase):
    """Tests for protocol registration and the base Protocol defaults.

    TestCase assertion methods are used instead of bare ``assert`` so the
    checks still run under ``python -O`` and report the values that differ.
    """

    def test_register_protocol(self):
        """A registered protocol can be added to a WSRoot by name."""
        # NOTE(review): this registers 'dummy' in the module-global registry
        # and never removes it, so the registration outlives this test.
        wsme.protocol.register_protocol(DummyProtocol)
        self.assertEqual(
            wsme.protocol.registered_protocols['dummy'], DummyProtocol)

        # A WSRoot built without arguments starts with no protocols.
        r = WSRoot()
        self.assertEqual(len(r.protocols), 0)

        # Adding by name instantiates the registered class.
        r.addprotocol('dummy')
        self.assertEqual(len(r.protocols), 1)
        self.assertEqual(r.protocols[0].__class__, DummyProtocol)

        # Passing the name to the constructor has the same effect.
        r = WSRoot(['dummy'])
        self.assertEqual(len(r.protocols), 1)
        self.assertEqual(r.protocols[0].__class__, DummyProtocol)

    def test_Protocol(self):
        """The base Protocol's hooks return None or ('none', 'N/A')."""
        p = wsme.protocol.Protocol()

        # Hooks that are expected to return None on the base class.
        none_results = [
            ('iter_calls', p.iter_calls(None)),
            ('extract_path', p.extract_path(None)),
            ('read_arguments', p.read_arguments(None)),
            ('encode_result', p.encode_result(None, None)),
        ]
        for method_name, result in none_results:
            self.assertTrue(
                result is None,
                '%s returned %r instead of None' % (method_name, result))

        # Sample encoders fall back to a placeholder pair.
        self.assertEqual(p.encode_sample_value(None, None), ('none', 'N/A'))
        self.assertEqual(p.encode_sample_params(None), ('none', 'N/A'))
        self.assertEqual(p.encode_sample_result(None, None), ('none', 'N/A'))
71