# Copyright (C) 2015 Cisco Systems, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import logging
import threading

import confluent_kafka
from confluent_kafka import KafkaException
from oslo_serialization import jsonutils
from oslo_utils import eventletutils
from oslo_utils import importutils

from oslo_messaging._drivers import base
from oslo_messaging._drivers import common as driver_common
from oslo_messaging._drivers.kafka_driver import kafka_options

if eventletutils.EVENTLET_AVAILABLE:
    tpool = importutils.try_import('eventlet.tpool')

LOG = logging.getLogger(__name__)


def unpack_message(msg):
35    """Unpack context and msg."""
36    context = {}
37    message = None
38    msg = jsonutils.loads(msg)
39    message = driver_common.deserialize_msg(msg)
40    context = message['_context']
41    del message['_context']
42    return context, message
43
44
45def pack_message(ctxt, msg):
46    """Pack context into msg."""
47    if isinstance(ctxt, dict):
48        context_d = ctxt
49    else:
50        context_d = ctxt.to_dict()
51    msg['_context'] = context_d
52
53    msg = driver_common.serialize_msg(msg)
54
55    return msg
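
# A rough sketch of the wire format produced above (illustrative only; the
# exact envelope keys come from driver_common.serialize_msg and can change
# between oslo.messaging versions):
#
#   pack_message({'user': 'admin'}, {'event_type': 'compute.create'})
#
# returns an oslo envelope dict whose payload is the original message with a
# '_context' key added; notify_send() below JSON-encodes that dict with
# jsonutils.dumps() before handing the bytes to the Kafka producer, and
# unpack_message() reverses both steps on the consumer side.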


def concat(sep, items):
    return sep.join(filter(bool, items))


def target_to_topic(target, priority=None, vhost=None):
    """Convert a target into a topic string.

    :param target: Message destination target
    :type target: oslo_messaging.Target
    :param priority: Notification priority
    :type priority: string
    :param vhost: Notification vhost
    :type vhost: string
    """
    return concat(".", [target.topic, priority, vhost])
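
# For illustration (hypothetical values): a target with topic 'notifications',
# priority 'info' and vhost 'vh1' maps to the Kafka topic
# 'notifications.info.vh1'. concat() drops falsy components, so with no
# priority or vhost the topic is just 'notifications'.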


class ConsumerTimeout(KafkaException):
    pass


class AssignedPartition(object):
    """This class is used by the ConsumerConnection to track the
    assigned partitions.
    """
    def __init__(self, topic, partition):
        super(AssignedPartition, self).__init__()
        self.topic = topic
        self.partition = partition
        self.skey = '%s %d' % (self.topic, self.partition)

    def to_dict(self):
        return {'topic': self.topic, 'partition': self.partition}
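
    # skey is a flat lookup key: for example, AssignedPartition(
    # 'notifications.info', 0) yields 'notifications.info 0'.
    # ConsumerConnection.find_assignment() builds the same string to look
    # an assignment back up.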


class Connection(object):
    """Base class holding the transport attributes shared by consumer and
    producer connections.
    """

    def __init__(self, conf, url):

        self.driver_conf = conf.oslo_messaging_kafka
        self.security_protocol = self.driver_conf.security_protocol
        self.sasl_mechanism = self.driver_conf.sasl_mechanism
        self.ssl_cafile = self.driver_conf.ssl_cafile
        self.ssl_client_cert_file = self.driver_conf.ssl_client_cert_file
        self.ssl_client_key_file = self.driver_conf.ssl_client_key_file
        self.ssl_client_key_password = self.driver_conf.ssl_client_key_password
        self.url = url
        self.virtual_host = url.virtual_host
        self._parse_url()

    def _parse_url(self):
        self.hostaddrs = []
        self.username = None
        self.password = None

        for host in self.url.hosts:
            # NOTE(ansmith): connections and failover are transparently
            # managed by the client library. Credentials will be
            # selected from the first host encountered in transport_url
            if self.username is None:
                self.username = host.username
                self.password = host.password
            else:
                if self.username != host.username:
                    LOG.warning("Different transport usernames detected")

            if host.hostname:
                self.hostaddrs.append("%s:%s" % (host.hostname, host.port))
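
    # For illustration (hypothetical URL): parsing
    # kafka://user:pass@broker1:9092,broker2:9092/vh1 leaves hostaddrs as
    # ['broker1:9092', 'broker2:9092'], takes the credentials from the first
    # host, and sets virtual_host to 'vh1'.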

    def reset(self):
        """Reset a connection so it can be used again."""
        pass


class ConsumerConnection(Connection):
    """Kafka consumer connection for a set of topics, tracking its
    assigned partitions.
    """
    def __init__(self, conf, url):

        super(ConsumerConnection, self).__init__(conf, url)
        self.consumer = None
        self.consumer_timeout = self.driver_conf.kafka_consumer_timeout
        self.max_fetch_bytes = self.driver_conf.kafka_max_fetch_bytes
        self.group_id = self.driver_conf.consumer_group
        self.use_auto_commit = self.driver_conf.enable_auto_commit
        self.max_poll_records = self.driver_conf.max_poll_records
        self._consume_loop_stopped = False
        self.assignment_dict = dict()

    def find_assignment(self, topic, partition):
        """Find and return existing assignment based on topic and partition"""
        skey = '%s %d' % (topic, partition)
        return self.assignment_dict.get(skey)

    def on_assign(self, consumer, topic_partitions):
        """Rebalance on_assign callback"""
        assignment = [AssignedPartition(p.topic, p.partition)
                      for p in topic_partitions]
        self.assignment_dict = {a.skey: a for a in assignment}
        for t in topic_partitions:
            LOG.debug("Topic %s assigned to partition %d",
                      t.topic, t.partition)

    def on_revoke(self, consumer, topic_partitions):
        """Rebalance on_revoke callback"""
        self.assignment_dict = dict()
        for t in topic_partitions:
            LOG.debug("Topic %s revoked from partition %d",
                      t.topic, t.partition)

    def _poll_messages(self, timeout):
        """Consume up to max_poll_records messages, letting rebalance
        callbacks run, and return the raw message payloads.
        """
        msglist = self.consumer.consume(self.max_poll_records,
                                        timeout)

        if not self.assignment_dict or not msglist:
            raise ConsumerTimeout()

        messages = []
        for message in msglist:
            if message is None:
                break
            a = self.find_assignment(message.topic(), message.partition())
            if a is None:
                LOG.warning(("Message for %s received on unassigned "
                             "partition %d"),
                            message.topic(), message.partition())
            else:
                messages.append(message.value())

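        # With auto-commit disabled, commit the consumed offsets synchronously
        # so they only advance once this batch is about to be handed back.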
        if not self.use_auto_commit:
            self.consumer.commit(asynchronous=False)

        return messages

    def consume(self, timeout=None):
        """Receive messages.

        :param timeout: poll timeout in seconds
        """

        def _raise_timeout(exc):
            raise driver_common.Timeout(str(exc))

        timer = driver_common.DecayingTimer(duration=timeout)
        timer.start()

        poll_timeout = (self.consumer_timeout if timeout is None
                        else min(timeout, self.consumer_timeout))

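        # Poll in a loop: when eventlet has monkey-patched threading, the
        # blocking poll is pushed to a native thread via tpool, and each
        # ConsumerTimeout shrinks poll_timeout through the decaying timer
        # until the caller's deadline expires.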
        while True:
            if self._consume_loop_stopped:
                return
            try:
                if eventletutils.is_monkey_patched('thread'):
                    return tpool.execute(self._poll_messages, poll_timeout)
                return self._poll_messages(poll_timeout)
            except ConsumerTimeout as exc:
                poll_timeout = timer.check_return(
                    _raise_timeout, exc, maximum=self.consumer_timeout)
            except Exception:
                LOG.exception("Failed to consume messages")
                return

    def stop_consuming(self):
        self._consume_loop_stopped = True

    def close(self):
        if self.consumer:
            self.consumer.close()
            self.consumer = None

    def declare_topic_consumer(self, topics, group=None):
        conf = {
            'bootstrap.servers': ",".join(self.hostaddrs),
            'group.id': (group or self.group_id),
            'enable.auto.commit': self.use_auto_commit,
            'max.partition.fetch.bytes': self.max_fetch_bytes,
            'security.protocol': self.security_protocol,
            'sasl.mechanism': self.sasl_mechanism,
            'sasl.username': self.username,
            'sasl.password': self.password,
            'ssl.ca.location': self.ssl_cafile,
            'ssl.certificate.location': self.ssl_client_cert_file,
            'ssl.key.location': self.ssl_client_key_file,
            'ssl.key.password': self.ssl_client_key_password,
            'enable.partition.eof': False,
            'default.topic.config': {'auto.offset.reset': 'latest'}
        }
        LOG.debug("Subscribing to %s as %s", topics, (group or self.group_id))
        self.consumer = confluent_kafka.Consumer(conf)
        self.consumer.subscribe(topics,
                                on_assign=self.on_assign,
                                on_revoke=self.on_revoke)
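
    # A minimal usage sketch (hypothetical group name): the notification
    # listener below does essentially
    #     conn.declare_topic_consumer(['notifications.info'], group='pool-1')
    # and then drains payloads with conn.consume() in a loop.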


class ProducerConnection(Connection):

    def __init__(self, conf, url):

        super(ProducerConnection, self).__init__(conf, url)
        self.batch_size = self.driver_conf.producer_batch_size
        self.linger_ms = self.driver_conf.producer_batch_timeout * 1000
        self.compression_codec = self.driver_conf.compression_codec
        self.producer = None
        self.producer_lock = threading.Lock()

    def _produce_message(self, topic, message):
        while True:
            try:
                self.producer.produce(topic, message)
            except KafkaException as e:
                LOG.error("Produce message failed: %s", e)
            except BufferError:
                LOG.debug("Produce message queue full, waiting for deliveries")
                self.producer.poll(0.5)
                continue
            break

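        # Serve delivery callbacks for previously produced messages without
        # blocking.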
        self.producer.poll(0)

    def notify_send(self, topic, ctxt, msg, retry):
        """Send messages to Kafka broker.

        :param topic: String of the topic
        :param ctxt: context for the messages
        :param msg: messages for publishing
        :param retry: the number of retries
        """
        # Normalize the retry count: guard against None before the numeric
        # comparison; None and negative values both mean "retry forever".
        retry = retry if retry is not None and retry >= 0 else None
        message = pack_message(ctxt, msg)
        message = jsonutils.dumps(message).encode('utf-8')

        try:
            self._ensure_producer()
            if eventletutils.is_monkey_patched('thread'):
                return tpool.execute(self._produce_message, topic, message)
            return self._produce_message(topic, message)
        except Exception:
            # NOTE(sileht): if something goes wrong close the producer
            # connection
            self._close_producer()
            raise

    def close(self):
        self._close_producer()

    def _close_producer(self):
        with self.producer_lock:
            if self.producer:
                try:
                    self.producer.flush()
                except KafkaException:
                    LOG.error("Flush error during producer close")
                self.producer = None

    def _ensure_producer(self):
        if self.producer:
            return
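        # Double-checked locking: re-test under the lock so concurrent callers
        # don't each build a producer.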
        with self.producer_lock:
            if self.producer:
                return
            conf = {
                'bootstrap.servers': ",".join(self.hostaddrs),
                'linger.ms': self.linger_ms,
                'batch.num.messages': self.batch_size,
                'compression.codec': self.compression_codec,
                'security.protocol': self.security_protocol,
                'sasl.mechanism': self.sasl_mechanism,
                'sasl.username': self.username,
                'sasl.password': self.password,
                'ssl.ca.location': self.ssl_cafile,
                'ssl.certificate.location': self.ssl_client_cert_file,
                'ssl.key.location': self.ssl_client_key_file,
                'ssl.key.password': self.ssl_client_key_password
            }
            self.producer = confluent_kafka.Producer(conf)


class OsloKafkaMessage(base.RpcIncomingMessage):

    def __init__(self, ctxt, message):
        super(OsloKafkaMessage, self).__init__(ctxt, message)

    def requeue(self):
        LOG.warning("requeue is not supported")

    def reply(self, reply=None, failure=None):
        LOG.warning("reply is not supported")

    def heartbeat(self):
        LOG.warning("heartbeat is not supported")


class KafkaListener(base.PollStyleListener):

    def __init__(self, conn):
        super(KafkaListener, self).__init__()
        self._stopped = eventletutils.Event()
        self.conn = conn
        self.incoming_queue = []

        # FIXME(sileht): We do a first poll to ensure the topics are created.
        # This is a workaround mainly for functional tests; in real life
        # it is fine if topics are not created synchronously
        self.poll(5)

    @base.batch_poll_helper
    def poll(self, timeout=None):
        while not self._stopped.is_set():
            if self.incoming_queue:
                return self.incoming_queue.pop(0)
            try:
                messages = self.conn.consume(timeout=timeout) or []
                for message in messages:
                    msg = OsloKafkaMessage(*unpack_message(message))
                    self.incoming_queue.append(msg)
            except driver_common.Timeout:
                return None

    def stop(self):
        self._stopped.set()
        self.conn.stop_consuming()

    def cleanup(self):
        self.conn.close()


class KafkaDriver(base.BaseDriver):
    """Kafka Driver

    See :doc:`kafka` for details.
    """

    def __init__(self, conf, url, default_exchange=None,
                 allowed_remote_exmods=None):
        conf = kafka_options.register_opts(conf, url)
        super(KafkaDriver, self).__init__(
            conf, url, default_exchange, allowed_remote_exmods)

        self.listeners = []
        self.virtual_host = url.virtual_host
        self.pconn = ProducerConnection(conf, url)

    def cleanup(self):
        self.pconn.close()
        for c in self.listeners:
            c.close()
        self.listeners = []
        LOG.info("Kafka messaging driver shutdown")

    def send(self, target, ctxt, message, wait_for_reply=None, timeout=None,
             call_monitor_timeout=None, retry=None, transport_options=None):
        raise NotImplementedError(
            'The RPC implementation for Kafka is not implemented')

    def send_notification(self, target, ctxt, message, version, retry=None):
        """Send notification to Kafka brokers

        :param target: Message destination target
        :type target: oslo_messaging.Target
        :param ctxt: Message context
        :type ctxt: dict
        :param message: Message payload to pass
        :type message: dict
        :param version: Messaging API version (currently not used)
        :type version: str
        :param retry: an optional retry count for sending
                      None means to retry forever
                      0 means no retry
                      N means N retries
        :type retry: int
        """
        self.pconn.notify_send(target_to_topic(target,
                                               vhost=self.virtual_host),
                               ctxt, message, retry)

    def listen(self, target, batch_size, batch_timeout):
        raise NotImplementedError(
            'The RPC implementation for Kafka is not implemented')

    def listen_for_notifications(self, targets_and_priorities, pool,
                                 batch_size, batch_timeout):
448        """Listen to a specified list of targets on Kafka brokers
449
450        :param targets_and_priorities: List of pairs (target, priority)
451                                       priority is not used for kafka driver
452                                       target.exchange_target.topic is used as
453                                       a kafka topic
454        :type targets_and_priorities: list
455        :param pool: consumer group of Kafka consumers
456        :type pool: string
457        """
        conn = ConsumerConnection(self.conf, self._url)
        topics = []
        for target, priority in targets_and_priorities:
            topics.append(target_to_topic(target, priority))

        conn.declare_topic_consumer(topics, pool)

        listener = KafkaListener(conn)
        return base.PollStyleListenerAdapter(listener, batch_size,
                                             batch_timeout)
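
# A minimal end-to-end sketch (a sketch only, assuming a reachable Kafka
# broker; 'broker:9092' and the topic name are placeholders), using the
# public oslo.messaging API rather than this driver module directly:
#
#   import oslo_messaging
#   from oslo_config import cfg
#
#   transport = oslo_messaging.get_notification_transport(
#       cfg.CONF, url='kafka://broker:9092/')
#   notifier = oslo_messaging.Notifier(transport, driver='messaging',
#                                      publisher_id='example',
#                                      topics=['notifications'])
#   notifier.info({}, 'example.event', {'hello': 'world'})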