/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.{Executors, ThreadFactory}

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal

import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}
import org.apache.kafka.common.TopicPartition

import org.apache.spark.internal.Logging
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}

/**
 * This class uses Kafka's own [[KafkaConsumer]] API to read data offsets from Kafka.
 * The [[ConsumerStrategy]] class defines which Kafka topics and partitions should be read
 * by this source. These strategies directly correspond to the different consumption options
 * offered by the source. This class is designed to return a configured [[KafkaConsumer]] that
 * is used by the [[KafkaSource]] to query for the offsets. See the docs on
 * [[org.apache.spark.sql.kafka010.ConsumerStrategy]]
 * for more details.
 *
 * Note: This class is not thread-safe.
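 *
 * A minimal usage sketch, assuming `reader` is an already constructed instance of this class
 * with suitable driver Kafka params (the calls below are the typical offset queries):
 * {{{
 *   val partitions = reader.fetchTopicPartitions()
 *   val earliest = reader.fetchEarliestOffsets()   // Map[TopicPartition, Long]
 *   val latest = reader.fetchLatestOffsets()       // Map[TopicPartition, Long]
 *   reader.close()
 * }}}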
 */
private[kafka010] class KafkaOffsetReader(
    consumerStrategy: ConsumerStrategy,
    driverKafkaParams: ju.Map[String, Object],
    readerOptions: Map[String, String],
    driverGroupIdPrefix: String) extends Logging {
  /**
   * Used to ensure that fetch operations execute in an [[UninterruptibleThread]].
   */
  val kafkaReaderThread = Executors.newSingleThreadExecutor(new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new UninterruptibleThread("Kafka Offset Reader") {
        override def run(): Unit = {
          r.run()
        }
      }
      t.setDaemon(true)
      t
    }
  })

  // Execution context backed by the single `kafkaReaderThread` above; used by
  // `runUninterruptibly` when the caller is not already on an UninterruptibleThread.
  val execContext = ExecutionContext.fromExecutorService(kafkaReaderThread)

  /**
   * Place [[groupId]] and [[nextId]] here so that they are initialized before any consumer is
   * created -- see SPARK-19564.
   */
  private var groupId: String = null
  private var nextId = 0

  /**
   * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
   * offsets and never commits them.
   */
  protected var consumer = createConsumer()

  private val maxOffsetFetchAttempts =
    readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt

  private val offsetFetchAttemptIntervalMs =
    readerOptions.getOrElse("fetchOffset.retryIntervalMs", "1000").toLong

  private def nextGroupId(): String = {
    groupId = driverGroupIdPrefix + "-" + nextId
    nextId += 1
    groupId
  }

  override def toString(): String = consumerStrategy.toString

  /**
   * Closes the connection to Kafka, and cleans up state.
   */
  def close(): Unit = {
    consumer.close()
    kafkaReaderThread.shutdownNow()
  }

  /**
   * @return The set of TopicPartitions currently assigned to the consumer, as determined by the
   *         [[ConsumerStrategy]]
   */
  def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
    assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
    // Poll to get the latest assigned partitions
    consumer.poll(0)
    val partitions = consumer.assignment()
    consumer.pause(partitions)
    partitions.asScala.toSet
  }

  /**
   * Resolves the specific offsets based on Kafka seek positions.
   * This method resolves offset value -1 to the latest and -2 to the
   * earliest Kafka seek position.
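   *
   * A sketch of a call (the topic name and offsets are illustrative; every currently assigned
   * TopicPartition must appear in the map):
   * {{{
   *   val resolved = fetchSpecificOffsets(Map(
   *     new TopicPartition("my-topic", 0) -> 42L,
   *     new TopicPartition("my-topic", 1) -> KafkaOffsetRangeLimit.LATEST,    // -1
   *     new TopicPartition("my-topic", 2) -> KafkaOffsetRangeLimit.EARLIEST)) // -2
   * }}}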
   */
  def fetchSpecificOffsets(
      partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] =
    runUninterruptibly {
      withRetriesWithoutInterrupt {
        // Poll to get the latest assigned partitions
        consumer.poll(0)
        val partitions = consumer.assignment()
        consumer.pause(partitions)
        assert(partitions.asScala == partitionOffsets.keySet,
          "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
            "Use -1 for latest, -2 for earliest, if you don't care.\n" +
            s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}")
        logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")

        partitionOffsets.foreach {
          case (tp, KafkaOffsetRangeLimit.LATEST) =>
            consumer.seekToEnd(ju.Arrays.asList(tp))
          case (tp, KafkaOffsetRangeLimit.EARLIEST) =>
            consumer.seekToBeginning(ju.Arrays.asList(tp))
          case (tp, off) => consumer.seek(tp, off)
        }
        partitionOffsets.map {
          case (tp, _) => tp -> consumer.position(tp)
        }
      }
    }

  /**
   * Fetch the earliest offsets for the topic partitions that are indicated
   * in the [[ConsumerStrategy]].
   */
  def fetchEarliestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
    withRetriesWithoutInterrupt {
      // Poll to get the latest assigned partitions
      consumer.poll(0)
      val partitions = consumer.assignment()
      consumer.pause(partitions)
      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning.")

      consumer.seekToBeginning(partitions)
      val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
      logDebug(s"Got earliest offsets for partitions: $partitionOffsets")
      partitionOffsets
    }
  }

  /**
   * Fetch the latest offsets for the topic partitions that are indicated
   * in the [[ConsumerStrategy]].
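   *
   * A sketch of how the earliest/latest pair can bound an offset range (variable names are
   * illustrative):
   * {{{
   *   val earliest = fetchEarliestOffsets()
   *   val latest = fetchLatestOffsets()
   *   // Records currently available per partition, assuming no truncation in between
   *   val counts = latest.map { case (tp, end) => tp -> (end - earliest(tp)) }
   * }}}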
   */
  def fetchLatestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
    withRetriesWithoutInterrupt {
      // Poll to get the latest assigned partitions
      consumer.poll(0)
      val partitions = consumer.assignment()
      consumer.pause(partitions)
      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.")

      consumer.seekToEnd(partitions)
      val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
      logDebug(s"Got latest offsets for partitions: $partitionOffsets")
      partitionOffsets
    }
  }

  /**
   * Fetch the earliest offsets for specific topic partitions.
   * The returned result may not contain some partitions if they have been deleted.
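   *
   * A sketch of how newly discovered partitions might be seeded (`knownPartitions` is a
   * hypothetical set tracked by the caller):
   * {{{
   *   val current = fetchTopicPartitions()
   *   val added = (current -- knownPartitions).toSeq
   *   val startOffsets = fetchEarliestOffsets(added)
   * }}}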
   */
  def fetchEarliestOffsets(
      newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = {
    if (newPartitions.isEmpty) {
      Map.empty[TopicPartition, Long]
    } else {
      runUninterruptibly {
        withRetriesWithoutInterrupt {
          // Poll to get the latest assigned partitions
          consumer.poll(0)
          val partitions = consumer.assignment()
          consumer.pause(partitions)
          logDebug(s"\tPartitions assigned to consumer: $partitions")

          // Get the earliest offset of each partition
          consumer.seekToBeginning(partitions)
          val partitionOffsets = newPartitions.filter { p =>
            // If a topic is deleted at the same time, some of its partitions may not be in
            // `partitions`, so we need to ignore them.
            partitions.contains(p)
          }.map(p => p -> consumer.position(p)).toMap
          logDebug(s"Got earliest offsets for new partitions: $partitionOffsets")
          partitionOffsets
        }
      }
    }
  }

  /**
   * This method ensures that the closure is called in an [[UninterruptibleThread]].
   * This is required when communicating with the [[KafkaConsumer]]. In the case
   * of streaming queries, we are already running in an [[UninterruptibleThread]];
   * however, for batch mode this is not the case.
   */
  private def runUninterruptibly[T](body: => T): T = {
    if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
      val future = Future {
        body
      }(execContext)
      ThreadUtils.awaitResult(future, Duration.Inf)
    } else {
      body
    }
  }

  /**
   * Helper function that does multiple retries on a body of code that returns offsets.
   * Retries are needed to handle transient failures. For example, race conditions between getting
   * assignment and getting position while topics/partitions are deleted can cause NPEs.
   *
   * This method also makes sure `body` won't be interrupted, to work around a potential issue in
   * `KafkaConsumer.poll` (KAFKA-1894).
   */
  private def withRetriesWithoutInterrupt(
      body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
    // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894)
    assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])

    synchronized {
      var result: Option[Map[TopicPartition, Long]] = None
      var attempt = 1
      var lastException: Throwable = null
      while (result.isEmpty && attempt <= maxOffsetFetchAttempts
        && !Thread.currentThread().isInterrupted) {
        Thread.currentThread match {
          case ut: UninterruptibleThread =>
            // "KafkaConsumer.poll" may hang forever if the thread is interrupted (e.g., the query
            // is stopped) (KAFKA-1894). Hence, we just make sure we don't interrupt it.
            //
            // If the broker addresses are wrong, or Kafka cluster is down, "KafkaConsumer.poll" may
            // hang forever as well. This cannot be resolved in KafkaSource until Kafka fixes the
            // issue.
            ut.runUninterruptibly {
              try {
                result = Some(body)
              } catch {
                case NonFatal(e) =>
                  lastException = e
                  logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e)
                  attempt += 1
                  Thread.sleep(offsetFetchAttemptIntervalMs)
                  resetConsumer()
              }
            }
          case _ =>
            throw new IllegalStateException(
              "Kafka APIs must be executed on a o.a.spark.util.UninterruptibleThread")
        }
      }
      if (Thread.interrupted()) {
        throw new InterruptedException()
      }
      if (result.isEmpty) {
        assert(attempt > maxOffsetFetchAttempts)
        assert(lastException != null)
        throw lastException
      }
      result.get
    }
  }

  /**
   * Create a consumer using the newly generated group id. We always create a new consumer when
   * retrying on Kafka errors, to avoid reusing a broken consumer that would likely fail again.
   */
  private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized {
    val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams)
    newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
    consumerStrategy.createConsumer(newKafkaParams)
  }

  private def resetConsumer(): Unit = synchronized {
    consumer.close()
    consumer = createConsumer()
  }
}

private[kafka010] object KafkaOffsetReader {

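  /**
   * Fixed schema of the records returned by the Kafka source: the raw key and value bytes plus
   * the topic, partition, offset, timestamp, and timestamp type of each record.
   */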
  def kafkaSchema: StructType = StructType(Seq(
    StructField("key", BinaryType),
    StructField("value", BinaryType),
    StructField("topic", StringType),
    StructField("partition", IntegerType),
    StructField("offset", LongType),
    StructField("timestamp", TimestampType),
    StructField("timestampType", IntegerType)
  ))
}