/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming.state

import java.util.concurrent.{ScheduledFuture, TimeUnit}
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils


/** Unique identifier for a [[StateStore]] */
case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)


/**
 * Base trait for a versioned key-value store used for streaming aggregations
 */
trait StateStore {

  /** Unique identifier of the store */
  def id: StateStoreId

  /** Version of the data in this store before committing updates. */
  def version: Long

  /** Get the current value of a key. */
  def get(key: UnsafeRow): Option[UnsafeRow]

  /** Put a new value for a key. */
  def put(key: UnsafeRow, value: UnsafeRow): Unit
  /**
   * Remove all keys for which the given condition returns true.
   */
  def remove(condition: UnsafeRow => Boolean): Unit

  /**
   * Commit all the updates that have been made to the store, and return the new version.
   */
  def commit(): Long

  /** Abort all the updates that have been made to the store. */
  def abort(): Unit

  /**
   * Iterator of store data after a set of updates have been committed.
   * This can be called only after committing all the updates made in the current thread.
   */
  def iterator(): Iterator[(UnsafeRow, UnsafeRow)]

  /**
   * Iterator of the updates that have been committed.
   * This can be called only after committing all the updates made in the current thread.
   */
  def updates(): Iterator[StoreUpdate]

  /** Number of keys in the state store */
  def numKeys(): Long

  /**
   * Whether all updates have been committed
   */
  private[streaming] def hasCommitted: Boolean
}
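
// For illustration only: a sketch of the expected lifecycle of a [[StateStore]]
// for one streaming batch. The identifiers `storeId`, `keySchema`, `valueSchema`,
// `storeConf`, `hadoopConf`, `v`, `key`, `value`, and `watermark` are
// hypothetical values of the appropriate types, not defined in this file.
//
//   val store = StateStore.get(storeId, keySchema, valueSchema, v, storeConf, hadoopConf)
//   try {
//     store.put(key, value)                   // stage an update
//     store.remove(_.getLong(0) < watermark)  // drop keys whose first column has expired
//     val newVersion = store.commit()         // persist staged updates as version v + 1
//   } catch {
//     case scala.util.control.NonFatal(e) =>
//       store.abort()                         // discard staged updates
//       throw e
//   }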


/** Trait representing a provider of [[StateStore]] instances for specific versions. */
trait StateStoreProvider {

  /** Get the store for an existing version of the data. */
  def getStore(version: Long): StateStore

  /** Optional method for providers to allow for background maintenance. */
  def doMaintenance(): Unit = { }
}


/** Trait representing updates made to a [[StateStore]]. */
sealed trait StoreUpdate {
  def key: UnsafeRow
  def value: UnsafeRow
}

case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate

case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate

case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
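
// For illustration only: consumers of `updates()` typically pattern match on the
// concrete update type after a commit; `sink` is a hypothetical downstream writer,
// not defined in this file.
//
//   store.updates().foreach {
//     case ValueAdded(key, value)   => sink.insert(key, value)
//     case ValueUpdated(key, value) => sink.upsert(key, value)
//     case ValueRemoved(key, _)     => sink.delete(key)
//   }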


/**
 * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
 * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
 * it also runs a periodic background task to do maintenance on the loaded stores. For each
 * store, it uses the [[StateStoreCoordinator]] to check whether the currently loaded instance of
 * the store is the active instance. Accordingly, it either keeps the store loaded and performs
 * maintenance, or unloads the store.
 */
object StateStore extends Logging {

  val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval"
  val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
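
  // The interval accepts any time string understood by `SparkConf.getTimeAsMs`;
  // for example, to run maintenance every 30 seconds on a hypothetical `sparkConf`:
  //
  //   sparkConf.set("spark.sql.streaming.stateStore.maintenanceInterval", "30s")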

  @GuardedBy("loadedProviders")
  private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()

  /**
   * Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
   * will be called when an exception happens.
   */
  class MaintenanceTask(periodMs: Long, task: => Unit, onError: => Unit) {
    private val executor =
      ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")

    private val runnable = new Runnable {
      override def run(): Unit = {
        try {
          task
        } catch {
          case NonFatal(e) =>
            logWarning("Error running maintenance thread", e)
            onError
            // Rethrow so that the scheduled executor suppresses subsequent runs.
            throw e
        }
      }
    }

    private val future: ScheduledFuture[_] = executor.scheduleAtFixedRate(
      runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)

    def stop(): Unit = {
      future.cancel(false)
      executor.shutdown()
    }

    def isRunning: Boolean = !future.isDone
  }
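
  // For illustration only: a stand-alone use of [[MaintenanceTask]]; the period
  // and the task bodies below are hypothetical.
  //
  //   val task = new MaintenanceTask(
  //     periodMs = 60L * 1000,
  //     task = { logDebug("periodic work") },
  //     onError = { logWarning("periodic work failed; task will be cancelled") })
  //   assert(task.isRunning)
  //   task.stop()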

  @GuardedBy("loadedProviders")
  private var maintenanceTask: MaintenanceTask = null

  @GuardedBy("loadedProviders")
  private var _coordRef: StateStoreCoordinatorRef = null

  /** Get or create a store associated with the id. */
  def get(
      storeId: StateStoreId,
      keySchema: StructType,
      valueSchema: StructType,
      version: Long,
      storeConf: StateStoreConf,
      hadoopConf: Configuration): StateStore = {
    require(version >= 0)
    val storeProvider = loadedProviders.synchronized {
      startMaintenanceIfNeeded()
      val provider = loadedProviders.getOrElseUpdate(
        storeId,
        new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf))
      reportActiveStoreInstance(storeId)
      provider
    }
    storeProvider.getStore(version)
  }

  /** Unload a state store provider */
  def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
    loadedProviders.remove(storeId)
  }

  /** Whether a state store provider is loaded or not */
  def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized {
    loadedProviders.contains(storeId)
  }

  def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
    maintenanceTask != null && maintenanceTask.isRunning
  }

  /** Unload and stop all state store providers */
  def stop(): Unit = loadedProviders.synchronized {
    loadedProviders.clear()
    _coordRef = null
    if (maintenanceTask != null) {
      maintenanceTask.stop()
      maintenanceTask = null
    }
    logInfo("StateStore stopped")
  }

  /** Start the periodic maintenance task if it is not already running and Spark is active. */
  private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized {
    val env = SparkEnv.get
    if (env != null && !isMaintenanceRunning) {
      val periodMs = env.conf.getTimeAsMs(
        MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s")
      maintenanceTask = new MaintenanceTask(
        periodMs,
        task = { doMaintenance() },
        onError = { loadedProviders.synchronized { loadedProviders.clear() } }
      )
      logInfo("State Store maintenance task started")
    }
  }

  /**
   * Execute background maintenance task in all the loaded store providers if they are still
   * the active instances according to the coordinator.
   */
  private def doMaintenance(): Unit = {
    logDebug("Doing maintenance")
    if (SparkEnv.get == null) {
      throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores")
    }
    loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
      try {
        if (verifyIfStoreInstanceActive(id)) {
          provider.doMaintenance()
        } else {
          unload(id)
          logInfo(s"Unloaded $provider")
        }
      } catch {
        case NonFatal(e) =>
          logWarning(s"Error managing $provider, stopping management thread")
          throw e
      }
    }
  }

  private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
    if (SparkEnv.get != null) {
      val host = SparkEnv.get.blockManager.blockManagerId.host
      val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
      coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
      logDebug(s"Reported that the loaded instance $storeId is active")
    }
  }

  private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
    if (SparkEnv.get != null) {
      val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
      val verified =
        coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
      logDebug(s"Verified whether the loaded instance $storeId is active: $verified")
      verified
    } else {
      false
    }
  }

  private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
    val env = SparkEnv.get
    if (env != null) {
      if (_coordRef == null) {
        _coordRef = StateStoreCoordinatorRef.forExecutor(env)
      }
      logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
      Some(_coordRef)
    } else {
      _coordRef = null
      None
    }
  }
}