1import txrestapi
2__package__="txrestapi"
3import re
4import os.path
5from twisted.internet import reactor
6from twisted.internet.defer import inlineCallbacks
7from twisted.web.resource import Resource, NoResource
8from twisted.web.server import Request, Site
9from twisted.web.client import getPage
10from twisted.trial import unittest
11from .resource import APIResource
12from .methods import GET, PUT
13
class FakeChannel(object):
    """Stand-in for an HTTP channel when building a bare ``Request``.

    The only attribute it offers is ``transport``, which is always ``None``.
    The tests pass it as the first argument of ``Request(...)`` so they can
    create requests without any real connection.
    """

    # No underlying connection exists in these unit tests.
    transport = None
16
def getRequest(method, url):
    """Build an in-memory ``Request`` with its method and path preset.

    :param method: HTTP method string stored on ``request.method``.
    :param url: value stored on ``request.path``.
    :returns: a ``Request`` attached to a ``FakeChannel`` (no transport).
    """
    request = Request(FakeChannel(), None)
    # Assigned left to right: method first, then path (same order as before).
    request.method, request.path = method, url
    return request
22
class APIResourceTest(unittest.TestCase):
    """Unit tests for APIResource registration and request dispatch.

    Requests are built in memory (``getRequest`` or a bare ``Request``) and
    handed straight to ``getChild`` / ``_get_callback``; no network is used.
    """

    def test_returns_normal_resources(self):
        # Children added with putChild are still returned by getChild.
        r = APIResource()
        a = Resource()
        r.putChild('a', a)
        req = Request(FakeChannel(), None)
        a_ = r.getChild('a', req)
        self.assertEqual(a, a_)

    def test_registry(self):
        # register() records a (method, compiled regex, callback) entry.
        compiled = re.compile('regex')
        r = APIResource()
        r.register('GET', 'regex', None)
        self.assertEqual([x[0] for x in r._registry], ['GET'])
        method, regex, callback = r._registry[0]
        self.assertEqual(method, 'GET')
        # Compare the pattern text and flags rather than the pattern objects.
        # On Python 2 pattern objects compare by identity, so comparing them
        # directly only worked because re's internal cache happened to hand
        # back the very same object.
        self.assertEqual((regex.pattern, regex.flags),
                         (compiled.pattern, compiled.flags))
        self.assertEqual(callback, None)

    def test_method_matching(self):
        # _get_callback selects the entry matching both method and path. The
        # first item of its result is the registered callback, and it returns
        # (None, None) when nothing matches.
        r = APIResource()
        r.register('GET', 'regex', 1)
        r.register('PUT', 'regex', 2)
        r.register('GET', 'another', 3)

        req = getRequest('GET', 'regex')
        result = r._get_callback(req)
        self.assertTrue(result)
        self.assertEqual(result[0], 1)

        req = getRequest('PUT', 'regex')
        result = r._get_callback(req)
        self.assertTrue(result)
        self.assertEqual(result[0], 2)

        req = getRequest('GET', 'another')
        result = r._get_callback(req)
        self.assertTrue(result)
        self.assertEqual(result[0], 3)

        req = getRequest('PUT', 'another')
        result = r._get_callback(req)
        self.assertEqual(result, (None, None))

    def test_callback(self):
        # Rendering the matched child returns whatever the callback returns.
        marker = object()
        def cb(request):
            return marker
        r = APIResource()
        r.register('GET', 'regex', cb)
        req = getRequest('GET', 'regex')
        result = r.getChild('regex', req)
        self.assertEqual(result.render(req), marker)

    def test_longerpath(self):
        # Matching uses the full request path, not just the first segment
        # passed to getChild.
        marker = object()
        r = APIResource()
        def cb(request):
            return marker
        r.register('GET', '/regex/a/b/c', cb)
        req = getRequest('GET', '/regex/a/b/c')
        result = r.getChild('regex', req)
        self.assertEqual(result.render(req), marker)

    def test_args(self):
        # Named regex groups are passed to the callback as keyword arguments.
        r = APIResource()
        def cb(request, **kwargs):
            return kwargs
        r.register('GET', '/(?P<a>[^/]*)/a/(?P<b>[^/]*)/c', cb)
        req = getRequest('GET', '/regex/a/b/c')
        result = r.getChild('regex', req)
        self.assertEqual(sorted(result.render(req).keys()), ['a', 'b'])

    def test_order(self):
        r = APIResource()
        def cb1(request, **kwargs):
            kwargs.update({'cb1':True})
            return kwargs
        def cb(request, **kwargs):
            return kwargs
        # Register two regexes that will match
        r.register('GET', '/(?P<a>[^/]*)/a/(?P<b>[^/]*)/c', cb1)
        r.register('GET', '/(?P<a>[^/]*)/a/(?P<b>[^/]*)', cb)
        req = getRequest('GET', '/regex/a/b/c')
        result = r.getChild('regex', req)
        # Make sure the first one got it
        self.assertIn('cb1', result.render(req))

    def test_no_resource(self):
        # A path matching no registration yields a NoResource child.
        r = APIResource()
        r.register('GET', '^/(?P<a>[^/]*)/a/(?P<b>[^/]*)$', None)
        req = getRequest('GET', '/definitely/not/a/match')
        result = r.getChild('regex', req)
        self.assertIsInstance(result, NoResource)

    def test_all(self):
        r = APIResource()
        def get_cb(r): return 'GET'
        def put_cb(r): return 'PUT'
        def all_cb(r): return 'ALL'
        r.register('GET', '^path', get_cb)
        r.register('ALL', '^path', all_cb)
        r.register('PUT', '^path', put_cb)
        # Test that the ALL registration picks it up before the PUT one
        for method in ('GET', 'PUT', 'ALL'):
            req = getRequest(method, 'path')
            result = r.getChild('path', req)
            self.assertEqual(result.render(req), 'ALL' if method=='PUT' else method)
129
130
class TestResource(Resource):
    """Leaf resource whose rendered body is always ``'aresource'``.

    ``TestAPI._on_gettest`` returns an instance of it. The decorator tests
    fetch ``/gettest`` and compare the response body against this string.
    """

    # Marks this resource as a leaf for twisted.web traversal.
    isLeaf = True

    def render(self, request):
        body = 'aresource'
        return body
135
136
class TestAPI(APIResource):
    """Small APIResource that DecoratorsTest serves over HTTP.

    Routes declared via the ``GET`` / ``PUT`` decorators:

    * GET or PUT ``/test...``: reply with ``'<METHOD> <segment>'``, where
      the segment is the ``a`` named group.
    * GET ``/gettest``: return a ``TestResource`` instance.
    """

    @GET('^/(?P<a>test[^/]*)/?')
    def _on_test_get(self, request, a):
        # ``a`` is the path segment captured by the named group.
        return '%s %s' % ('GET', a)

    @PUT('^/(?P<a>test[^/]*)/?')
    def _on_test_put(self, request, a):
        return '%s %s' % ('PUT', a)

    @GET('^/gettest')
    def _on_gettest(self, request):
        # Hands back a Resource instead of a string; test_resource_wrapper
        # expects that resource's rendered body in the response.
        return TestResource()
150
151
class DecoratorsTest(unittest.TestCase):
    """End-to-end checks of the route decorators.

    ``TestAPI`` is served on an OS-chosen loopback port, and its routes are
    fetched with ``getPage``.
    """

    def _listen(self, site):
        # Port 0 lets the OS pick a free port; bind to loopback only.
        return reactor.listenTCP(0, site, interface="127.0.0.1")

    def setUp(self):
        self.port = self._listen(Site(TestAPI(), timeout=None))
        self.portno = self.port.getHost().port

    def tearDown(self):
        # Return stopListening's result so trial can wait on it.
        return self.port.stopListening()

    def getURL(self, path):
        return "http://127.0.0.1:%d/%s" % (self.portno, path)

    def _fetch(self, path, method):
        """Request *path* with HTTP *method*; return getPage's Deferred."""
        return getPage(self.getURL(path), method=method)

    @inlineCallbacks
    def test_get(self):
        body = yield self._fetch('test_thing/', 'GET')
        self.assertEqual(body, 'GET test_thing')

    @inlineCallbacks
    def test_put(self):
        body = yield self._fetch('test_thing/', 'PUT')
        self.assertEqual(body, 'PUT test_thing')

    @inlineCallbacks
    def test_resource_wrapper(self):
        body = yield self._fetch('gettest', 'GET')
        self.assertEqual(body, 'aresource')
185
186
def test_suite():
    """Build the package's test suite.

    It contains DecoratorsTest, APIResourceTest and the README.rst doctests.
    The README path is resolved relative to this module's directory (one
    level up).

    :returns: a ``twisted.trial.unittest.TestSuite`` with all three groups.
    """
    # Import doctest directly. The old code reached it through
    # twisted.trial.unittest, which exposed it only incidentally and newer
    # Twisted releases do not provide it.
    import doctest
    import unittest as ut
    # loadTestsFromTestCase is the non-deprecated equivalent of makeSuite:
    # same 'test' method prefix, same sorting, same suite class.
    loader = ut.defaultTestLoader
    suite = unittest.TestSuite()
    suite.addTest(loader.loadTestsFromTestCase(DecoratorsTest))
    suite.addTest(loader.loadTestsFromTestCase(APIResourceTest))
    # DocFileSuite treats this as a module-relative path, which must use '/'
    # separators regardless of OS (os.path.join is not appropriate here).
    suite.addTest(doctest.DocFileSuite('../README.rst'))
    return suite
194
195