1from __future__ import absolute_import
2from __future__ import unicode_literals
3
4import ssl
5import json
6import time
7import socket
8import logging
9
10from ._compat import string_types, to_bytes, struct_l
11from .version import __version__
12
13try:
14    from .snappy_socket import SnappySocket, SnappyEncoder
15except ImportError:
16    SnappyEncoder = SnappySocket = None  # pyflakes.ignore
17
18import tornado.iostream
19import tornado.ioloop
20
21from nsq import event, protocol
22from .deflate_socket import DeflateSocket, DeflateEncoder
23
logger = logging.getLogger(__name__)


# Connection states stored in ``AsyncConn.state``.
INIT = 'INIT'                  # freshly constructed; connect() not called yet
DISCONNECTED = 'DISCONNECTED'  # set by the stream close callback
CONNECTING = 'CONNECTING'      # TCP connect issued, not yet completed
CONNECTED = 'CONNECTED'        # TCP connected (IDENTIFY may still be pending)


# Default IDENTIFY ``user_agent``: "pynsq/<library version>".
DEFAULT_USER_AGENT = 'pynsq/{0}'.format(__version__)
35
36
class DefaultEncoder(object):
    """Pass-through encoder used while no stream compression is active.

    ``AsyncConn.send`` always routes outgoing bytes through an encoder; this
    one hands the payload back untouched.
    """

    @staticmethod
    def encode(data):
        """Return ``data`` exactly as given (identity transform)."""
        passthrough = data
        return passthrough
42
43
class AsyncConn(event.EventedMixin):
    """
    Low level object representing a TCP connection to nsqd.

    When a message on this connection is requeued and the requeue delay
    has not been specified, it calculates the delay automatically by an
    increasing multiple of ``requeue_delay``.

    Generates the following events that can be listened to with
    :meth:`nsq.AsyncConn.on`:

     * ``connect``
     * ``close``
     * ``error``
     * ``identify``
     * ``identify_response``
     * ``auth``
     * ``auth_response``
     * ``heartbeat``
     * ``ready``
     * ``message``
     * ``response``
     * ``backoff``
     * ``resume``
     * ``continue``

    :param host: the host to connect to

    :param port: the port to connect to

    :param timeout: the timeout for read/write operations (in seconds);
        an ``int`` is accepted and stored as a ``float``

    :param heartbeat_interval: the amount of time (in seconds) to negotiate
        with the connected producers to send heartbeats (requires nsqd 0.2.19+)

    :param requeue_delay: the base multiple used when calculating requeue delay
        (multiplied by # of attempts)

    :param tls_v1: enable TLS v1 encryption (requires nsqd 0.2.22+)

    :param tls_options: dictionary of options passed as ``ssl_options`` to
        Tornado's ``IOStream.start_tls()``; they are merged over the defaults
        ``cert_reqs=ssl.CERT_REQUIRED`` and
        ``ssl_version=ssl.PROTOCOL_TLSv1_2``

    :param snappy: enable Snappy stream compression (requires nsqd 0.2.23+)

    :param deflate: enable deflate stream compression (requires nsqd 0.2.23+)

    :param deflate_level: configure the deflate compression level for this
        connection (requires nsqd 0.2.23+)

    :param output_buffer_size: size of the buffer (in bytes) used by nsqd
        for buffering writes to this connection

    :param output_buffer_timeout: timeout (in ms) used by nsqd before
        flushing buffered writes (set to 0 to disable).  **Warning**:
        configuring clients with an extremely low (``< 25ms``)
        ``output_buffer_timeout`` has a significant effect on ``nsqd``
        CPU usage (particularly with ``> 50`` clients connected).

    :param sample_rate: take only a sample of the messages being sent
        to the client. Not setting this or setting it to 0 will ensure
        you get all the messages destined for the client.
        Sample rate must be between 0 and 99 (inclusive); a value from 1 to
        99 means the client will receive that percentage of the message
        traffic. (requires nsqd 0.2.25+)

    :param user_agent: a string identifying the agent for this client
        in the spirit of HTTP (default: ``<client_library_name>/<version>``)
        (requires nsqd 0.2.25+)

    :param auth_secret: a string passed when using nsq auth
        (requires nsqd 1.0+)

    :param msg_timeout: the amount of time (in seconds) that nsqd will wait
        before considering messages that have been delivered to this
        consumer timed out (requires nsqd 0.2.28+)

    :param hostname: a string identifying the host where this client runs
        (default: ``<hostname>``); the part before the first ``.`` is sent
        as the ``client_id``
    """
    def __init__(
            self,
            host,
            port,
            timeout=1.0,
            heartbeat_interval=30,
            requeue_delay=90,
            tls_v1=False,
            tls_options=None,
            snappy=False,
            deflate=False,
            deflate_level=6,
            user_agent=DEFAULT_USER_AGENT,
            output_buffer_size=16 * 1024,
            output_buffer_timeout=250,
            sample_rate=0,
            auth_secret=None,
            msg_timeout=None,
            hostname=None):
        assert isinstance(host, string_types)
        assert isinstance(port, int)
        assert isinstance(timeout, (float, int))
        assert isinstance(tls_options, (dict, None.__class__))
        assert isinstance(deflate_level, int)
        assert isinstance(heartbeat_interval, int) and heartbeat_interval >= 1
        assert isinstance(requeue_delay, int) and requeue_delay >= 0
        assert isinstance(output_buffer_size, int) and output_buffer_size >= 0
        assert isinstance(output_buffer_timeout, int) and output_buffer_timeout >= 0
        assert isinstance(sample_rate, int) and sample_rate >= 0 and sample_rate < 100
        assert msg_timeout is None or (isinstance(msg_timeout, (float, int)) and msg_timeout > 0)
        # auth_secret validated by to_bytes() below

        self.state = INIT
        self.host = host
        self.port = port
        self.timeout = float(timeout)
        self.last_recv_timestamp = time.time()
        self.last_msg_timestamp = time.time()
        self.in_flight = 0
        self.rdy = 0
        self.rdy_timeout = None
        # for backwards compatibility when interacting with older nsqd
        # (pre 0.2.20), default this to their hard-coded max
        self.max_rdy_count = 2500
        self.tls_v1 = tls_v1
        self.tls_options = tls_options
        self.snappy = snappy
        self.deflate = deflate
        self.deflate_level = deflate_level
        self.hostname = hostname
        if self.hostname is None:
            self.hostname = socket.gethostname()
        self.short_hostname = self.hostname.split('.')[0]
        # IDENTIFY expects these durations in milliseconds
        self.heartbeat_interval = heartbeat_interval * 1000
        self.msg_timeout = int(msg_timeout * 1000) if msg_timeout else None
        self.requeue_delay = requeue_delay

        self.output_buffer_size = output_buffer_size
        self.output_buffer_timeout = output_buffer_timeout
        self.sample_rate = sample_rate
        self.user_agent = user_agent

        self._authentication_required = False  # tracking server auth state
        self.auth_secret = to_bytes(auth_secret) if auth_secret else None

        self.socket = None
        self.stream = None
        # features acknowledged by nsqd in the IDENTIFY response, upgraded
        # one at a time by _on_response_continue
        self._features_to_enable = []

        self.last_rdy = 0

        self.callback_queue = []
        # replaced by a compressing encoder after a snappy/deflate upgrade
        self.encoder = DefaultEncoder()

        super(AsyncConn, self).__init__()

    @property
    def id(self):
        """Identifier of this connection, ``"host:port"``."""
        return str(self)

    def __str__(self):
        return self.host + ':' + str(self.port)

    def connected(self):
        """Return True when the TCP connection has been established."""
        return self.state == CONNECTED

    def connecting(self):
        """Return True while the TCP connect is in progress."""
        return self.state == CONNECTING

    def closed(self):
        """Return True if never connected or the stream has been closed."""
        return self.state in (INIT, DISCONNECTED)

    def connect(self):
        """Start a non-blocking TCP connection to ``host:port``.

        Does nothing unless the connection is closed (``INIT`` or
        ``DISCONNECTED``). Completion is handled by ``_connect_callback``.
        """
        if not self.closed():
            return

        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # NOTE: setblocking(0) right below puts the socket in non-blocking
        # mode, which supersedes this timeout at the socket level.
        self.socket.settimeout(self.timeout)
        self.socket.setblocking(0)

        self.stream = tornado.iostream.IOStream(self.socket)
        self.stream.set_close_callback(self._socket_close)
        self.stream.set_nodelay(True)

        self.state = CONNECTING
        self.on(event.CONNECT, self._on_connect)
        self.on(event.DATA, self._on_data)

        self.stream.connect((self.host, self.port), self._connect_callback)

    def _connect_callback(self):
        """Send the protocol magic, start the read loop, trigger ``connect``."""
        self.state = CONNECTED
        self.stream.write(protocol.MAGIC_V2)
        self._start_read()
        self.trigger(event.CONNECT, conn=self)

    def _read_bytes(self, size, callback):
        """Ask the stream for ``size`` bytes, delivered to ``callback``.

        If the stream raises IOError (e.g. it is already closed) the
        connection is closed and an ``error`` event is triggered.
        """
        try:
            self.stream.read_bytes(size, callback)
        except IOError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.ConnectionClosedError('Stream is closed'),
            )

    def _start_read(self):
        """Read the 4-byte size prefix of the next frame."""
        if self.stream is None:
            return  # IOStream.start_tls() invalidates stream, will call again when ready
        self._read_bytes(4, self._read_size)

    def _socket_close(self):
        """Stream close callback: mark DISCONNECTED and trigger ``close``."""
        self.state = DISCONNECTED
        self.trigger(event.CLOSE, conn=self)

    def close(self):
        """Close the underlying stream.

        The stream's close callback (``_socket_close``) then marks the
        connection ``DISCONNECTED`` and triggers ``close``. This is a no-op
        when there is no stream: before :meth:`connect`, or while a TLS
        upgrade has detached the old stream and the new one is not ready.
        """
        if self.stream is None:
            # previously raised AttributeError, which escaped the
            # ``except`` blocks of every send-failure handler
            return
        self.stream.close()

    def _read_size(self, data):
        """Decode the frame size prefix and read that many body bytes."""
        try:
            size = struct_l.unpack(data)[0]
        except Exception:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError('failed to unpack size'),
            )
            return
        self._read_bytes(size, self._read_body)

    def _read_body(self, data):
        """Dispatch a frame body as a ``data`` event and re-arm the reader.

        Exceptions from listeners are logged, not propagated, so a faulty
        handler cannot stop the read loop.
        """
        try:
            self.trigger(event.DATA, conn=self, data=data)
        except Exception:
            logger.exception('uncaught exception in data event')
        self._start_read()

    def send(self, data):
        """Encode ``data`` with the active encoder and write it to the stream."""
        self.stream.write(self.encoder.encode(data))

    def upgrade_to_tls(self, options=None):
        """Replace the plain stream with a TLS stream via ``start_tls``.

        ``options`` override the default ``ssl_options``. ``self.stream``
        is None until the handshake future resolves; reading resumes after.
        """
        # in order to upgrade to TLS we need to *replace* the IOStream...
        opts = {
            'cert_reqs': ssl.CERT_REQUIRED,
            'ssl_version': ssl.PROTOCOL_TLSv1_2
        }
        opts.update(options or {})

        fut = self.stream.start_tls(False, ssl_options=opts, server_hostname=self.host)
        self.stream = None

        def finish_upgrade_tls(fut):
            try:
                self.stream = fut.result()
                self.socket = self.stream.socket
                self._start_read()
            except Exception as e:
                # skip self.close() because no stream
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('failed to upgrade to TLS', e),
                )

        tornado.ioloop.IOLoop.current().add_future(fut, finish_upgrade_tls)

    def upgrade_to_snappy(self):
        """Wrap the current socket with Snappy stream compression."""
        assert SnappySocket, 'snappy requires the python-snappy package'

        # in order to upgrade to Snappy we need to use whatever IOStream
        # is currently in place (normal or SSL)...
        #
        # first read any compressed bytes the existing IOStream might have
        # already buffered and use that to bootstrap the SnappySocket, then
        # monkey patch the existing IOStream by replacing its socket
        # with a wrapper that will automagically handle compression.
        existing_data = self.stream._consume(self.stream._read_buffer_size)
        self.socket = SnappySocket(self.socket)
        self.socket.bootstrap(existing_data)
        self.stream.socket = self.socket
        self.encoder = SnappyEncoder()

    def upgrade_to_deflate(self):
        """Wrap the current socket with DEFLATE stream compression."""
        # in order to upgrade to DEFLATE we need to use whatever IOStream
        # is currently in place (normal or SSL)...
        #
        # first read any compressed bytes the existing IOStream might have
        # already buffered and use that to bootstrap the DeflateSocket, then
        # monkey patch the existing IOStream by replacing its socket
        # with a wrapper that will automagically handle compression.
        existing_data = self.stream._consume(self.stream._read_buffer_size)
        self.socket = DeflateSocket(self.socket, self.deflate_level)
        self.socket.bootstrap(existing_data)
        self.stream.socket = self.socket
        self.encoder = DeflateEncoder(level=self.deflate_level)

    def send_rdy(self, value):
        """Send ``RDY value``; return True on success, False on failure.

        On failure the connection is closed and an ``error`` event fires;
        ``rdy``/``last_rdy`` are only updated on success.
        """
        try:
            self.send(protocol.ready(value))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError('failed to send RDY %d' % value, e),
            )
            return False
        self.last_rdy = value
        self.rdy = value
        return True

    def _on_connect(self, **kwargs):
        """Build and send the IDENTIFY command once TCP is connected."""
        identify_data = {
            'short_id': self.short_hostname,  # TODO remove when deprecating pre 1.0 support
            'long_id': self.hostname,  # TODO remove when deprecating pre 1.0 support
            'client_id': self.short_hostname,
            'hostname': self.hostname,
            'heartbeat_interval': self.heartbeat_interval,
            'feature_negotiation': True,
            'tls_v1': self.tls_v1,
            'snappy': self.snappy,
            'deflate': self.deflate,
            'deflate_level': self.deflate_level,
            'output_buffer_timeout': self.output_buffer_timeout,
            'output_buffer_size': self.output_buffer_size,
            'sample_rate': self.sample_rate,
            'user_agent': self.user_agent
        }
        if self.msg_timeout:
            identify_data['msg_timeout'] = self.msg_timeout
        self.trigger(event.IDENTIFY, conn=self, data=identify_data)
        self.on(event.RESPONSE, self._on_identify_response)
        try:
            self.send(protocol.identify(identify_data))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError('failed to bootstrap connection', e),
            )

    def _on_identify_response(self, data, **kwargs):
        """Handle the IDENTIFY response and queue negotiated upgrades.

        A plain ``OK`` means nsqd does not support feature negotiation and
        the connection is immediately ``ready``; otherwise the JSON body
        decides which of tls_v1/snappy/deflate to enable, whether AUTH is
        required, and ``max_rdy_count``.
        """
        self.off(event.RESPONSE, self._on_identify_response)

        if data == b'OK':
            logger.warning('nsqd version does not support feature negotiation')
            return self.trigger(event.READY, conn=self)

        try:
            data = json.loads(data.decode('utf-8'))
        except ValueError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError(
                    'failed to parse IDENTIFY response JSON from nsqd - %r' %
                    data
                ),
            )
            return

        self.trigger(event.IDENTIFY_RESPONSE, conn=self, data=data)

        if self.tls_v1 and data.get('tls_v1'):
            self._features_to_enable.append('tls_v1')
        if self.snappy and data.get('snappy'):
            self._features_to_enable.append('snappy')
        if self.deflate and data.get('deflate'):
            self._features_to_enable.append('deflate')

        if data.get('auth_required'):
            self._authentication_required = True

        if data.get('max_rdy_count'):
            self.max_rdy_count = data.get('max_rdy_count')
        else:
            # for backwards compatibility when interacting with older nsqd
            # (pre 0.2.20), default this to their hard-coded max
            logger.warning('setting max_rdy_count to default value of 2500')
            self.max_rdy_count = 2500

        self.on(event.RESPONSE, self._on_response_continue)
        self._on_response_continue(conn=self, data=None)

    def _on_response_continue(self, data, **kwargs):
        """Enable the next negotiated feature, then AUTH or trigger ``ready``.

        Called once directly and then once per server response; each
        upgrade is acknowledged by the server with another response.
        """
        if self._features_to_enable:
            feature = self._features_to_enable.pop(0)
            if feature == 'tls_v1':
                self.upgrade_to_tls(self.tls_options)
            elif feature == 'snappy':
                self.upgrade_to_snappy()
            elif feature == 'deflate':
                self.upgrade_to_deflate()
            # the server will 'OK' after these connection upgrades triggering another response
            return

        self.off(event.RESPONSE, self._on_response_continue)
        if self.auth_secret and self._authentication_required:
            self.on(event.RESPONSE, self._on_auth_response)
            self.trigger(event.AUTH, conn=self, data=self.auth_secret)
            try:
                self.send(protocol.auth(self.auth_secret))
            except Exception as e:
                self.close()
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('Error sending AUTH', e),
                )
            return
        self.trigger(event.READY, conn=self)

    def _on_auth_response(self, data, **kwargs):
        """Parse the AUTH response JSON and trigger ``ready`` on success."""
        # unregister first so the listener never outlives this response,
        # matching _on_identify_response
        self.off(event.RESPONSE, self._on_auth_response)

        try:
            data = json.loads(data.decode('utf-8'))
        except ValueError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError(
                    'failed to parse AUTH response JSON from nsqd - %r' % data
                ),
            )
            return

        self.trigger(event.AUTH_RESPONSE, conn=self, data=data)
        return self.trigger(event.READY, conn=self)

    def _on_data(self, data, **kwargs):
        """Dispatch one decoded frame by type.

        Messages become ``message`` events, heartbeats are answered with a
        NOP, other responses become ``response`` events and error frames
        become ``error`` events.
        """
        self.last_recv_timestamp = time.time()
        frame, data = protocol.unpack_response(data)
        if frame == protocol.FRAME_TYPE_MESSAGE:
            self.last_msg_timestamp = time.time()
            self.in_flight += 1

            message = protocol.decode_message(data)
            message.on(event.FINISH, self._on_message_finish)
            message.on(event.REQUEUE, self._on_message_requeue)
            message.on(event.TOUCH, self._on_message_touch)

            self.trigger(event.MESSAGE, conn=self, message=message)
        elif frame == protocol.FRAME_TYPE_RESPONSE and data == b'_heartbeat_':
            try:
                self.send(protocol.nop())
            except Exception as e:
                # handle like every other failed send instead of letting the
                # exception escape to _read_body's generic logger
                self.close()
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('failed to send NOP', e),
                )
                return
            self.trigger(event.HEARTBEAT, conn=self)
        elif frame == protocol.FRAME_TYPE_RESPONSE:
            self.trigger(event.RESPONSE, conn=self, data=data)
        elif frame == protocol.FRAME_TYPE_ERROR:
            self.trigger(event.ERROR, conn=self, error=protocol.Error(data))

    def _on_message_requeue(self, message, backoff=True, time_ms=-1, **kwargs):
        """Send REQ for ``message``.

        A negative ``time_ms`` means "auto": ``requeue_delay * attempts``
        seconds, converted to ms. Triggers ``backoff`` or ``continue``
        first depending on ``backoff``.
        """
        if backoff:
            self.trigger(event.BACKOFF, conn=self)
        else:
            self.trigger(event.CONTINUE, conn=self)

        self.in_flight -= 1
        try:
            time_ms = self.requeue_delay * message.attempts * 1000 if time_ms < 0 else time_ms
            self.send(protocol.requeue(message.id, time_ms))
        except Exception as e:
            self.close()
            self.trigger(event.ERROR, conn=self, error=protocol.SendError(
                'failed to send REQ %s @ %d' % (message.id, time_ms), e))

    def _on_message_finish(self, message, **kwargs):
        """Trigger ``resume`` and send FIN for ``message``."""
        self.trigger(event.RESUME, conn=self)

        self.in_flight -= 1
        try:
            self.send(protocol.finish(message.id))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError('failed to send FIN %s' % message.id, e),
            )

    def _on_message_touch(self, message, **kwargs):
        """Send TOUCH for ``message`` to reset its server-side timeout."""
        try:
            self.send(protocol.touch(message.id))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError('failed to send TOUCH %s' % message.id, e),
            )
539