/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming.state

import java.util.concurrent.{ScheduledFuture, TimeUnit}
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils


/** Unique identifier for a [[StateStore]] */
case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)


/**
 * Base trait for a versioned key-value store used for streaming aggregations
 */
trait StateStore {

  /** Unique identifier of the store */
  def id: StateStoreId

  /** Version of the data in this store before committing updates. */
  def version: Long

  /** Get the current value of a key. */
  def get(key: UnsafeRow): Option[UnsafeRow]

  /** Put a new value for a key. */
  def put(key: UnsafeRow, value: UnsafeRow): Unit

  /**
   * Remove all keys that match the given condition.
   *
   * @param condition predicate evaluated on each key; keys for which it returns true are removed
   */
  def remove(condition: UnsafeRow => Boolean): Unit

  /**
   * Commit all the updates that have been made to the store, and return the new version.
   */
  def commit(): Long

  /** Abort all the updates that have been made to the store. */
  def abort(): Unit

  /**
   * Iterator of store data after a set of updates have been committed.
   * This can be called only after committing all the updates made in the current thread.
   */
  def iterator(): Iterator[(UnsafeRow, UnsafeRow)]

  /**
   * Iterator of the updates that have been committed.
   * This can be called only after committing all the updates made in the current thread.
   */
  def updates(): Iterator[StoreUpdate]

  /** Number of keys in the state store */
  def numKeys(): Long

  /**
   * Whether all updates have been committed
   */
  private[streaming] def hasCommitted: Boolean
}


/** Trait representing a provider of a specific version of a [[StateStore]]. */
trait StateStoreProvider {

  /** Get the store with the existing version. */
  def getStore(version: Long): StateStore

  /** Optional method for providers to allow for background maintenance. No-op by default. */
  def doMaintenance(): Unit = { }
}


/** Trait representing updates made to a [[StateStore]]. */
sealed trait StoreUpdate {
  def key: UnsafeRow
  def value: UnsafeRow
}

case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate

case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate

case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate


/**
 * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
 * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
 * it also runs a periodic background task to do maintenance on the loaded stores. For each
 * store, it uses the [[StateStoreCoordinator]] to check whether the currently loaded instance of
 * the store is the active instance. Accordingly, it either keeps it loaded and performs
 * maintenance, or unloads the store.
 */
object StateStore extends Logging {

  val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval"
  val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60

  @GuardedBy("loadedProviders")
  private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()

  /**
   * Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
   * will be called when an exception happens.
   *
   * The exception is rethrown after `onError` runs. Per the `scheduleAtFixedRate` contract,
   * that suppresses all subsequent executions and completes `future`, so `isRunning` becomes
   * false. Note that this does NOT shut down `executor`; call [[stop]] to release it.
   *
   * @param periodMs initial delay and period between runs, in milliseconds
   * @param task     by-name body evaluated on every run
   * @param onError  by-name callback evaluated once when `task` throws a non-fatal exception
   */
  class MaintenanceTask(periodMs: Long, task: => Unit, onError: => Unit) {
    private val executor =
      ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")

    private val runnable = new Runnable {
      override def run(): Unit = {
        try {
          task
        } catch {
          case NonFatal(e) =>
            logWarning("Error running maintenance thread", e)
            onError
            // Rethrow so the scheduler stops rescheduling this task.
            throw e
        }
      }
    }

    private val future: ScheduledFuture[_] = executor.scheduleAtFixedRate(
      runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)

    /** Cancel future runs (without interrupting a running one) and shut down the executor. */
    def stop(): Unit = {
      future.cancel(false)
      executor.shutdown()
    }

    /** True until the task is cancelled via [[stop]] or has terminated with an exception. */
    def isRunning: Boolean = !future.isDone
  }

  @GuardedBy("loadedProviders")
  private var maintenanceTask: MaintenanceTask = null

  @GuardedBy("loadedProviders")
  private var _coordRef: StateStoreCoordinatorRef = null

  /**
   * Get or create a store associated with the id.
   *
   * On first use of `storeId` an [[HDFSBackedStateStoreProvider]] is created and cached. Every
   * call also (re)starts the maintenance task if needed and reports this instance as active to
   * the coordinator.
   *
   * @param storeId     id of the store to load
   * @param keySchema   schema of the keys, used only when a new provider is created
   * @param valueSchema schema of the values, used only when a new provider is created
   * @param version     version of the store to load; must be non-negative
   * @param storeConf   store configuration, used only when a new provider is created
   * @param hadoopConf  Hadoop configuration, used only when a new provider is created
   * @return the store at `version`, obtained from the (possibly cached) provider
   * @throws IllegalArgumentException if `version` is negative
   */
  def get(
      storeId: StateStoreId,
      keySchema: StructType,
      valueSchema: StructType,
      version: Long,
      storeConf: StateStoreConf,
      hadoopConf: Configuration): StateStore = {
    require(version >= 0)
    val storeProvider = loadedProviders.synchronized {
      startMaintenanceIfNeeded()
      val provider = loadedProviders.getOrElseUpdate(
        storeId,
        new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf))
      reportActiveStoreInstance(storeId)
      provider
    }
    // Loading the requested version happens outside the lock.
    storeProvider.getStore(version)
  }

  /** Unload a state store provider */
  def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
    loadedProviders.remove(storeId)
  }

  /** Whether a state store provider is loaded or not */
  def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized {
    loadedProviders.contains(storeId)
  }

  /** Whether a maintenance task exists and has neither been stopped nor failed. */
  def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
    maintenanceTask != null && maintenanceTask.isRunning
  }

  /** Unload and stop all state store providers */
  def stop(): Unit = loadedProviders.synchronized {
    loadedProviders.clear()
    _coordRef = null
    if (maintenanceTask != null) {
      maintenanceTask.stop()
      maintenanceTask = null
    }
    logInfo("StateStore stopped")
  }

  /** Start the periodic maintenance task if not already started and if Spark is active */
  private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized {
    val env = SparkEnv.get
    if (env != null && !isMaintenanceRunning) {
      // A task that failed is no longer running, but its executor was never shut down
      // (MaintenanceTask.stop() is the only shutdown path). Release it before replacing the
      // reference; otherwise every failure leaks one scheduler thread.
      if (maintenanceTask != null) {
        maintenanceTask.stop()
        maintenanceTask = null
      }
      val periodMs = env.conf.getTimeAsMs(
        MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s")
      maintenanceTask = new MaintenanceTask(
        periodMs,
        task = { doMaintenance() },
        onError = { loadedProviders.synchronized { loadedProviders.clear() } }
      )
      logInfo("State Store maintenance task started")
    }
  }

  /**
   * Execute background maintenance task in all the loaded store providers if they are still
   * the active instances according to the coordinator.
   *
   * @throws IllegalStateException if SparkEnv is not active
   */
  private def doMaintenance(): Unit = {
    logDebug("Doing maintenance")
    if (SparkEnv.get == null) {
      throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores")
    }
    // Snapshot under the lock, then do the (possibly slow) per-provider work without holding it.
    loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
      try {
        if (verifyIfStoreInstanceActive(id)) {
          provider.doMaintenance()
        } else {
          unload(id)
          logInfo(s"Unloaded $provider")
        }
      } catch {
        case NonFatal(e) =>
          logWarning(s"Error managing $provider, stopping management thread")
          throw e
      }
    }
  }

  /** Tell the coordinator (if reachable) that this executor holds the active instance. */
  private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
    // Read SparkEnv once so that the null check and the dereference see the same value.
    Option(SparkEnv.get).foreach { env =>
      val blockManagerId = env.blockManager.blockManagerId
      val host = blockManagerId.host
      val executorId = blockManagerId.executorId
      coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
      logDebug(s"Reported that the loaded instance $storeId is active")
    }
  }

  /**
   * Ask the coordinator whether this executor's instance of `storeId` is the active one.
   *
   * @return false if SparkEnv is not active or no coordinator reference is available
   */
  private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
    // Read SparkEnv once so that the null check and the dereference see the same value.
    Option(SparkEnv.get) match {
      case Some(env) =>
        val executorId = env.blockManager.blockManagerId.executorId
        val verified =
          coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
        logDebug(s"Verified whether the loaded instance $storeId is active: $verified")
        verified
      case None =>
        false
    }
  }

  /**
   * Lazily created reference to the coordinator.
   *
   * The reference is cached while SparkEnv is active and reset when it is not.
   */
  private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
    val env = SparkEnv.get
    if (env != null) {
      if (_coordRef == null) {
        _coordRef = StateStoreCoordinatorRef.forExecutor(env)
      }
      logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
      Some(_coordRef)
    } else {
      _coordRef = null
      None
    }
  }
}