1#!/usr/bin/env python
2
3try:
4    import asyncio
5except ImportError:
6    import trollius as asyncio
7import unittest
8import socket
9import sys
10
11import aiodns
12import pycares
13
14
class DNSTest(unittest.TestCase):
    """Integration tests for ``aiodns.DNSResolver``.

    These tests send real DNS queries for public names (google.com,
    jabber.org, sip2sip.info, ...), so they need working network access
    and DNS resolution.
    """

    def setUp(self):
        # Every test runs on its own fresh event loop; the loop is closed by
        # the registered cleanup, which unittest runs after tearDown().
        self.loop = asyncio.new_event_loop()
        self.addCleanup(self.loop.close)
        self.resolver = aiodns.DNSResolver(loop=self.loop)

    def tearDown(self):
        # Drop our reference to the resolver before the loop-close cleanup.
        self.resolver = None

    def _assert_dns_error(self, fut, code):
        """Run *fut* on the test loop and assert it fails with DNS error *code*.

        Fails the test if *fut* completes successfully (the previous
        ``try/except`` pattern passed silently in that case) or if it fails
        with a ``DNSError`` whose first argument is not *code*.
        """
        with self.assertRaises(aiodns.error.DNSError) as cm:
            self.loop.run_until_complete(fut)
        self.assertEqual(cm.exception.args[0], code)

    def test_query_a(self):
        f = self.resolver.query('google.com', 'A')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_a_bad(self):
        """Querying a nonexistent name must fail with ARES_ENOTFOUND."""
        f = self.resolver.query('hgf8g2od29hdohid.com', 'A')
        self._assert_dns_error(f, aiodns.error.ARES_ENOTFOUND)

    def test_query_aaaa(self):
        f = self.resolver.query('ipv6.google.com', 'AAAA')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_cname(self):
        f = self.resolver.query('livechat.ripe.net', 'CNAME')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_mx(self):
        f = self.resolver.query('google.com', 'MX')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_ns(self):
        f = self.resolver.query('google.com', 'NS')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_txt(self):
        f = self.resolver.query('google.com', 'TXT')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_soa(self):
        f = self.resolver.query('google.com', 'SOA')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_srv(self):
        f = self.resolver.query('_xmpp-server._tcp.jabber.org', 'SRV')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_naptr(self):
        f = self.resolver.query('sip2sip.info', 'NAPTR')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_ptr(self):
        # PTR lookups take the reverse-DNS form of the address, built here
        # by pycares.reverse_address().
        ip = '8.8.8.8'
        f = self.resolver.query(pycares.reverse_address(ip), 'PTR')
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_query_bad_type(self):
        """An unknown query type is rejected synchronously with ValueError."""
        self.assertRaises(ValueError, self.resolver.query, 'google.com', 'XXX')

    def test_query_timeout(self):
        """A query to an unresponsive nameserver must fail with ARES_ETIMEOUT."""
        # Replace the default resolver with one using a short timeout and a
        # nameserver address that is not expected to answer.
        self.resolver = aiodns.DNSResolver(timeout=0.1, loop=self.loop)
        self.resolver.nameservers = ['1.2.3.4']
        f = self.resolver.query('google.com', 'A')
        self._assert_dns_error(f, aiodns.error.ARES_ETIMEOUT)

    def test_query_cancel(self):
        """Cancelling the resolver fails pending queries with ARES_ECANCELLED."""
        f = self.resolver.query('google.com', 'A')
        self.resolver.cancel()
        self._assert_dns_error(f, aiodns.error.ARES_ECANCELLED)

#    def test_future_cancel(self):
#        # TODO: write this in such a way it also works with trollius
#        f = self.resolver.query('google.com', 'A')
#        f.cancel()
#        def coro():
#            yield from asyncio.sleep(0.1, loop=self.loop)
#            yield from f
#        try:
#            self.loop.run_until_complete(coro())
#        except asyncio.CancelledError as e:
#            self.assertTrue(e)

    def test_query_twice(self):
        """One resolver can serve several sequential queries from a coroutine."""
        # The coroutine is compiled from source text because each interpreter
        # family needs syntax the others cannot parse: async/await (3.5+),
        # "yield from" (3.3/3.4) or trollius' "yield From(...)" (Python 2).
        # Compare version_info tuples: the old string comparison of
        # sys.version[:3] treated "3.10" as "3.1" < "3.3".  async/await is
        # used on 3.5+ because @asyncio.coroutine was removed in 3.11.
        if sys.version_info >= (3, 5):
            src = '''if 1:
            async def coro(self, host, qtype, n=2):
                for _ in range(n):
                    result = await self.resolver.query(host, qtype)
                    self.assertTrue(result)
            '''
        elif sys.version_info >= (3, 3):
            src = '''if 1:
            @asyncio.coroutine
            def coro(self, host, qtype, n=2):
                for _ in range(n):
                    result = yield from self.resolver.query(host, qtype)
                    self.assertTrue(result)
            '''
        else:
            src = '''if 1:
            @asyncio.coroutine
            def coro(self, host, qtype, n=2):
                for _ in range(n):
                    result = yield asyncio.From(self.resolver.query(host, qtype))
                    self.assertTrue(result)
            '''

        # Execute into an explicit namespace: fetching the new name back via
        # locals() is unreliable (PEP 667 makes locals() a fresh snapshot in
        # Python 3.13+).  Module globals supply the ``asyncio`` name.
        namespace = {}
        exec(src, globals(), namespace)
        coro = namespace['coro']
        self.loop.run_until_complete(coro(self, 'gmail.com', 'MX'))

    def test_gethostbyname(self):
        f = self.resolver.gethostbyname("google.com", socket.AF_INET)
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_gethostbyname_ipv6(self):
        f = self.resolver.gethostbyname("ipv6.google.com", socket.AF_INET6)
        result = self.loop.run_until_complete(f)
        self.assertTrue(result)

    def test_gethostbyname_bad_family(self):
        """An invalid address family surfaces as a DNSError from the future."""
        f = self.resolver.gethostbyname("ipv6.google.com", -1)
        with self.assertRaises(aiodns.error.DNSError):
            self.loop.run_until_complete(f)
150
151
if __name__ == "__main__":
    # Run the whole suite when executed as a script, printing each test name.
    unittest.main(verbosity=2)
154
155