1from __future__ import absolute_import
2from __future__ import print_function
3from __future__ import unicode_literals
4
5import os
6import sys
7import signal
8import subprocess
9import time
10import ssl
11
12import tornado.httpclient
13import tornado.testing
14import tornado.web
15import tornado.httpserver
16
# The test suite lives in a 'tests' subdirectory. Put the project root (its
# parent directory) at the front of sys.path so that the local 'nsq' package
# imported below is found ahead of any other copy on the path.
_tests_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.abspath(os.path.join(_tests_dir, '..'))
if base_dir not in sys.path:
    sys.path.insert(0, base_dir)
21
22from nsq import protocol
23from nsq.conn import AsyncConn
24from nsq.reader import Reader
25from nsq.deflate_socket import DeflateSocket
26from nsq.snappy_socket import SnappySocket
27
28
class IntegrationBase(tornado.testing.AsyncTestCase):
    """Base class for tests that run against real nsqd / nsqlookupd daemons.

    Subclasses set ``nsqd_command`` and/or ``nsqlookupd_command`` to the argv
    used to launch each daemon. ``setUp`` starts the configured daemons
    (nsqlookupd first) and blocks until their HTTP ``/ping`` endpoints answer
    ``OK``. ``tearDown`` SIGKILLs and reaps them.
    """
    # argv for nsqd; empty means "do not start nsqd"
    nsqd_command = []
    # argv for nsqlookupd; empty means "do not start nsqlookupd"
    nsqlookupd_command = []

    def setUp(self):
        super(IntegrationBase, self).setUp()
        if not hasattr(self, 'processes'):
            self.processes = []

        try:
            if self.nsqlookupd_command:
                self._start_process(self.nsqlookupd_command,
                                    'http://127.0.0.1:4161/ping')
            if self.nsqd_command:
                self._start_process(self.nsqd_command,
                                    'http://127.0.0.1:4151/ping')
        except BaseException:
            # unittest does not call tearDown when setUp raises, so reap any
            # daemon that was already launched instead of leaking it.
            self._kill_processes()
            raise

    def tearDown(self):
        try:
            super(IntegrationBase, self).tearDown()
        finally:
            # Always reap the daemons, even if tornado's tearDown fails.
            self._kill_processes()

    def _start_process(self, command, ping_endpoint):
        """Launch ``command`` and block until ``ping_endpoint`` answers OK."""
        proc = subprocess.Popen(command)
        # Track it before pinging so a failed ping still gets it killed.
        self.processes.append(proc)
        self.wait_ping(ping_endpoint)

    def _kill_processes(self):
        """SIGKILL and reap every tracked daemon, most recently started first.

        The list is drained as it goes, so calling this twice is harmless.
        """
        while self.processes:
            proc = self.processes.pop()
            if proc.poll() is None:
                os.kill(proc.pid, signal.SIGKILL)
            proc.wait()

    def wait_ping(self, endpoint, timeout=5):
        """Block until an HTTP GET of ``endpoint`` returns the body ``b'OK'``.

        Polls every 0.1 seconds. Once ``timeout`` seconds have elapsed it
        re-raises the last fetch error, or fails the test if the endpoint was
        reachable but never answered ``OK``.
        """
        deadline = time.time() + timeout
        http = tornado.httpclient.HTTPClient()
        try:
            while True:
                try:
                    resp = http.fetch(endpoint)
                except Exception:
                    # Daemon is probably not listening yet; keep polling.
                    if time.time() > deadline:
                        raise
                else:
                    if resp.body == b'OK':
                        return
                    if time.time() > deadline:
                        self.fail('%s did not answer OK within %s seconds (last body: %r)'
                                  % (endpoint, timeout, resp.body))
                time.sleep(0.1)
        finally:
            http.close()

    def _send_messages(self, topic, count, body):
        """Publish ``count`` copies of ``body`` to ``topic`` on the local nsqd.

        The PUB commands are sent from the connection's ``ready`` callback.
        They presumably go out only while the IOLoop runs (e.g. inside
        ``self.wait()``). This helper never closes the connection.
        """
        conn = AsyncConn('127.0.0.1', 4150)
        conn.connect()

        def _on_ready(*args, **kwargs):
            for _ in range(count):
                conn.send(protocol.pub(topic, body))

        conn.on('ready', _on_ready)
77
78
class ReaderIntegrationTest(IntegrationBase):
    """Exercise AsyncConn and Reader against an nsqd with snappy and TLS.

    nsqd is started with ``--snappy`` and the key/cert pair under ``tests/``.
    Clients request both features through ``identify_options``.
    """
    # IDENTIFY options used by the connections/readers below. tls_options is
    # presumably forwarded to the SSL socket wrapper; ssl.CERT_NONE means the
    # client does not verify nsqd's certificate.
    identify_options = {
        'user_agent': 'sup',
        'snappy': True,
        'tls_v1': True,
        'tls_options': {'cert_reqs': ssl.CERT_NONE},
        'heartbeat_interval': 10,
        'output_buffer_size': 4096,
        'output_buffer_timeout': 50
    }

    nsqd_command = ['nsqd', '--verbose', '--snappy',
                    '--tls-key=%s/tests/key.pem' % base_dir,
                    '--tls-cert=%s/tests/cert.pem' % base_dir]

    def test_bad_reader_arguments(self):
        """Reader raises AssertionError for an unknown option ('foo')."""
        topic = 'test_reader_msgs_%s' % time.time()
        bad_options = dict(self.identify_options)
        bad_options.update(dict(foo=10))

        def handler(msg):
            return None

        self.assertRaises(
            AssertionError,
            Reader,
            nsqd_tcp_addresses=['127.0.0.1:4150'],
            topic=topic, channel='ch',
            message_handler=handler, max_in_flight=100,
            **bad_options)

    def test_conn_identify(self):
        """A default AsyncConn gets a dict identify_response from nsqd."""
        c = AsyncConn('127.0.0.1', 4150)
        c.on('identify_response', self.stop)
        c.connect()
        response = self.wait()
        print(response)
        assert response['conn'] is c
        assert isinstance(response['data'], dict)

    def test_conn_identify_options(self):
        """nsqd acknowledges the snappy and tls_v1 feature negotiation."""
        c = AsyncConn('127.0.0.1', 4150, **self.identify_options)
        c.on('identify_response', self.stop)
        c.connect()
        response = self.wait()
        print(response)
        assert response['conn'] is c
        assert isinstance(response['data'], dict)
        assert response['data']['snappy'] is True
        assert response['data']['tls_v1'] is True

    def test_conn_socket_upgrade(self):
        """After negotiation the socket is snappy-wrapped over an SSL socket."""
        c = AsyncConn('127.0.0.1', 4150, **self.identify_options)
        c.on('ready', self.stop)
        c.connect()
        self.wait()
        assert isinstance(c.socket, SnappySocket)
        assert isinstance(c.socket._socket, ssl.SSLSocket)

    def test_conn_subscribe(self):
        """SUB over a negotiated connection is answered with OK."""
        topic = 'test_conn_suscribe_%s' % time.time()
        c = AsyncConn('127.0.0.1', 4150, **self.identify_options)

        def _on_ready(*args, **kwargs):
            c.on('response', self.stop)
            c.send(protocol.subscribe(topic, 'ch'))

        c.on('ready', _on_ready)
        c.connect()
        response = self.wait()
        print(response)
        assert response['conn'] is c
        assert response['data'] == b'OK'

    def test_conn_messages(self):
        """With RDY 5, a raw connection receives the 5 published messages."""
        self.msg_count = 0

        topic = 'test_conn_suscribe_%s' % time.time()
        self._send_messages(topic, 5, b'sup')

        c = AsyncConn('127.0.0.1', 4150, **self.identify_options)

        def _on_message(*args, **kwargs):
            self.msg_count += 1
            # Messages are never FIN'd here, so in_flight presumably grows
            # to the RDY count of 5.
            if c.in_flight == 5:
                self.stop()

        def _on_ready(*args, **kwargs):
            c.on('message', _on_message)
            c.send(protocol.subscribe(topic, 'ch'))
            c.send_rdy(5)

        c.on('ready', _on_ready)
        c.connect()

        self.wait()
        assert self.msg_count == 5

    def test_reader_messages(self):
        """A Reader delivers every published message to its handler."""
        self.msg_count = 0
        num_messages = 500

        topic = 'test_reader_msgs_%s' % time.time()
        self._send_messages(topic, num_messages, b'sup')

        def handler(msg):
            assert msg.body == b'sup'
            self.msg_count += 1
            if self.msg_count >= num_messages:
                self.stop()
            return True

        r = Reader(nsqd_tcp_addresses=['127.0.0.1:4150'], topic=topic, channel='ch',
                   message_handler=handler, max_in_flight=100,
                   **self.identify_options)
        try:
            self.wait()
        finally:
            # Close the reader even if wait() timed out.
            r.close()

    def test_reader_heartbeat(self):
        """Reader.heartbeat is invoked repeatedly at the configured interval."""
        this = self
        this.count = 0

        def handler(msg):
            return True

        class HeartbeatReader(Reader):
            def heartbeat(self, conn):
                # Count heartbeats and stop the test loop on the second one.
                this.count += 1
                if this.count == 2:
                    this.stop()

        topic = 'test_reader_hb_%s' % time.time()
        r = HeartbeatReader(nsqd_tcp_addresses=['127.0.0.1:4150'], topic=topic, channel='ch',
                            message_handler=handler, max_in_flight=100,
                            heartbeat_interval=1)
        try:
            self.wait()
        finally:
            # The original never closed this reader; close it like the others.
            r.close()
216
217
class DeflateReaderIntegrationTest(IntegrationBase):
    """Exercise AsyncConn and Reader against an nsqd with deflate enabled."""
    # NOTE(review): tls_options is set but tls_v1 is not, so TLS is presumably
    # not negotiated in this class. Confirm whether tls_options is leftover.
    identify_options = {
        'user_agent': 'sup',
        'deflate': True,
        'deflate_level': 6,
        'tls_options': {'cert_reqs': ssl.CERT_NONE},
        'heartbeat_interval': 10,
        'output_buffer_size': 4096,
        'output_buffer_timeout': 50
    }

    nsqd_command = ['nsqd', '--verbose', '--deflate',
                    '--tls-key=%s/tests/key.pem' % base_dir,
                    '--tls-cert=%s/tests/cert.pem' % base_dir]

    def test_conn_identify_options(self):
        """nsqd acknowledges deflate in its identify_response."""
        c = AsyncConn('127.0.0.1', 4150, **self.identify_options)
        c.on('identify_response', self.stop)
        c.connect()
        response = self.wait()
        print(response)
        assert response['conn'] is c
        assert isinstance(response['data'], dict)
        assert response['data']['deflate'] is True

    def test_conn_socket_upgrade(self):
        """After negotiation the connection's socket is a DeflateSocket."""
        c = AsyncConn('127.0.0.1', 4150, **self.identify_options)
        c.on('ready', self.stop)
        c.connect()
        self.wait()
        assert isinstance(c.socket, DeflateSocket)

    def test_reader_messages(self):
        """A deflate-enabled Reader delivers every published message."""
        self.msg_count = 0
        num_messages = 300
        topic = 'test_reader_msgs_%s' % time.time()
        self._send_messages(topic, num_messages, b'sup')

        def handler(msg):
            assert msg.body == b'sup'
            self.msg_count += 1
            if self.msg_count >= num_messages:
                self.stop()
            return True

        r = Reader(nsqd_tcp_addresses=['127.0.0.1:4150'],
                   topic=topic, channel='ch',
                   message_handler=handler, max_in_flight=100,
                   **self.identify_options)
        try:
            self.wait()
        finally:
            # Close the reader even if wait() timed out.
            r.close()
269
270
class LookupdReaderIntegrationTest(IntegrationBase):
    """A Reader that finds nsqd via nsqlookupd's HTTP API receives messages."""
    # nsqd is pointed at the nsqlookupd started alongside it (TCP port 4160).
    nsqd_command = ['nsqd', '--verbose', '-lookupd-tcp-address=127.0.0.1:4160']
    nsqlookupd_command = ['nsqlookupd', '--verbose', '-broadcast-address=127.0.0.1']

    identify_options = {
        'user_agent': 'sup',
        'heartbeat_interval': 10,
        'output_buffer_size': 4096,
        'output_buffer_timeout': 50
    }

    def test_reader_messages(self):
        """A message published to nsqd reaches a lookupd-configured Reader."""
        self.msg_count = 0
        num_messages = 1

        topic = 'test_lookupd_msgs_%s' % int(time.time())
        self._send_messages(topic, num_messages, b'sup')

        def handler(msg):
            assert msg.body == b'sup'
            self.msg_count += 1
            if self.msg_count >= num_messages:
                self.stop()
            return True

        r = Reader(lookupd_http_addresses=['http://127.0.0.1:4161'],
                   topic=topic,
                   channel='ch',
                   message_handler=handler,
                   lookupd_poll_interval=1,
                   lookupd_poll_jitter=0.01,
                   max_in_flight=100,
                   **self.identify_options)
        try:
            self.wait()
        finally:
            # Close the reader even if wait() timed out.
            r.close()
307
308
class AuthHandler(tornado.web.RequestHandler):
    """In-process stand-in for the auth HTTP service nsqd is pointed at.

    A request whose ``secret`` query argument is ``opensesame`` is granted
    subscribe access to topic ``authtopic`` / channel ``ch`` for 30 (ttl).
    Any other secret gets an empty 403 response.
    """

    def get(self):
        if self.get_argument('secret') != 'opensesame':
            self.set_status(403)
            return

        grant = {
            'permissions': ['subscribe'],
            'topic': 'authtopic',
            'channels': ['ch'],
        }
        self.finish({
            'ttl': 30,
            'identity': 'username',
            'authorizations': [grant]
        })
323
324
class ReaderAuthIntegrationTest(IntegrationBase):
    """Subscribing succeeds when nsqd authorizes the client's auth_secret.

    nsqd is started with ``--auth-http-address`` pointing at a local port
    reserved in ``setUp``. The test serves ``AuthHandler`` on that port.
    """
    identify_options = {
        'user_agent': 'supauth',
        'heartbeat_interval': 10,
        'output_buffer_size': 4096,
        'output_buffer_timeout': 50,
        'auth_secret': "opensesame",
    }

    def setUp(self):
        # Reserve the auth server's port before nsqd starts, because nsqd
        # receives the auth address on its command line.
        auth_sock, auth_port = tornado.testing.bind_unused_port()
        self.auth_sock = auth_sock
        self.nsqd_command = [
            'nsqd', '--verbose', '--auth-http-address=127.0.0.1:%d' % auth_port
        ]
        try:
            super(ReaderAuthIntegrationTest, self).setUp()
        except BaseException:
            # tearDown is skipped when setUp raises; don't leak the socket.
            self.auth_sock.close()
            raise

    def tearDown(self):
        try:
            super(ReaderAuthIntegrationTest, self).tearDown()
        finally:
            self.auth_sock.close()

    def test_conn_identify(self):
        """SUB to the authorized topic/channel is answered with OK."""
        auth_app = tornado.web.Application([("/auth", AuthHandler)])
        auth_srv = tornado.httpserver.HTTPServer(auth_app)
        auth_srv.add_socket(self.auth_sock)

        try:
            c = AsyncConn('127.0.0.1', 4150, **self.identify_options)

            def _on_ready(*args, **kwargs):
                c.on('response', self.stop)
                c.send(protocol.subscribe('authtopic', 'ch'))

            c.on('ready', _on_ready)
            c.connect()
            response = self.wait()
        finally:
            # Stop serving auth requests even if wait() timed out.
            auth_srv.stop()
        print(response)
        assert response['conn'] is c
        assert response['data'] == b'OK'
364