/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import java.io.IOException

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils

/**
 * Trait for [[MLWriter]] and [[MLReader]].
 */
private[util] sealed trait BaseReadWrite {
  private var optionSparkSession: Option[SparkSession] = None

  /**
   * Sets the Spark SQLContext to use for saving/loading.
   */
  @Since("1.6.0")
  @deprecated("Use session instead. This method will be removed in 2.2.0.", "2.0.0")
  def context(sqlContext: SQLContext): this.type = {
    optionSparkSession = Option(sqlContext.sparkSession)
    this
  }

  /**
   * Sets the Spark Session to use for saving/loading.
   */
  @Since("2.0.0")
  def session(sparkSession: SparkSession): this.type = {
    optionSparkSession = Option(sparkSession)
    this
  }

  /**
   * Returns the user-specified Spark Session or the default.
   */
  protected final def sparkSession: SparkSession = {
    if (optionSparkSession.isEmpty) {
      optionSparkSession = Some(SparkSession.builder().getOrCreate())
    }
    optionSparkSession.get
  }

  /**
   * Returns the user-specified SQL context or the default.
   */
  protected final def sqlContext: SQLContext = sparkSession.sqlContext

  /** Returns the underlying `SparkContext`. */
  protected final def sc: SparkContext = sparkSession.sparkContext
}
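
// A minimal usage sketch (the `model` value and paths below are hypothetical):
// writers and readers built on BaseReadWrite accept an explicit SparkSession,
// and otherwise fall back to SparkSession.builder().getOrCreate().
//
//   val spark = SparkSession.builder().appName("demo").getOrCreate()
//   model.write.session(spark).save("/tmp/model")    // explicit session
//   model.write.save("/tmp/model-default")           // default session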

/**
 * Abstract class for utility classes that can save ML instances.
 */
@Since("1.6.0")
abstract class MLWriter extends BaseReadWrite with Logging {

  protected var shouldOverwrite: Boolean = false

  /**
   * Saves the ML instance to the input path.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {
    val hadoopConf = sc.hadoopConfiguration
    val outputPath = new Path(path)
    val fs = outputPath.getFileSystem(hadoopConf)
    val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    if (fs.exists(qualifiedOutputPath)) {
      if (shouldOverwrite) {
        logInfo(s"Path $path already exists. It will be overwritten.")
        // TODO: Revert back to the original content if save is not successful.
        fs.delete(qualifiedOutputPath, true)
      } else {
        throw new IOException(
          s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
      }
    }
    saveImpl(path)
  }

  /**
   * [[save()]] handles overwriting and then calls this method. Subclasses should override this
   * method to implement the actual saving of the instance.
   */
  @Since("1.6.0")
  protected def saveImpl(path: String): Unit

  /**
   * Overwrites if the output path already exists.
   */
  @Since("1.6.0")
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
 * Trait for classes that provide [[MLWriter]].
 */
@Since("1.6.0")
trait MLWritable {

  /**
   * Returns an [[MLWriter]] instance for this ML instance.
   */
  @Since("1.6.0")
  def write: MLWriter

  /**
   * Saves this ML instance to the input path, a shortcut for `write.save(path)`.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = write.save(path)
}

/**
 * :: DeveloperApi ::
 *
 * Helper trait for making simple `Params` types writable. If a `Params` class stores
 * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
 * a default implementation of writing saved instances of the class.
 * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
 * [[org.apache.spark.sql.Dataset]].
 *
 * @see `DefaultParamsReadable`, the counterpart to this trait
 */
@DeveloperApi
trait DefaultParamsWritable extends MLWritable { self: Params =>

  override def write: MLWriter = new DefaultParamsWriter(this)
}

/**
 * Abstract class for utility classes that can load ML instances.
 *
 * @tparam T ML instance type
 */
@Since("1.6.0")
abstract class MLReader[T] extends BaseReadWrite {

  /**
   * Loads the ML component from the input path.
   */
  @Since("1.6.0")
  def load(path: String): T

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
 * Trait for objects that provide [[MLReader]].
 *
 * @tparam T ML instance type
 */
@Since("1.6.0")
trait MLReadable[T] {

  /**
   * Returns an [[MLReader]] instance for this class.
   */
  @Since("1.6.0")
  def read: MLReader[T]

  /**
   * Reads an ML instance from the input path, a shortcut for `read.load(path)`.
   *
   * @note Implementing classes should override this to be Java-friendly.
   */
  @Since("1.6.0")
  def load(path: String): T = read.load(path)
}
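
// A minimal round-trip sketch, assuming a stage such as LogisticRegression whose
// companion object mixes in MLReadable (the path below is hypothetical).
// Overwriting must be requested explicitly; otherwise save() throws an
// IOException when the path already exists.
//
//   val lr = new LogisticRegression().setMaxIter(5)
//   lr.write.overwrite().save("/tmp/lr")
//   val restored = LogisticRegression.load("/tmp/lr")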

/**
 * :: DeveloperApi ::
 *
 * Helper trait for making simple `Params` types readable. If a `Params` class stores
 * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
 * a default implementation of reading saved instances of the class.
 * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
 * [[org.apache.spark.sql.Dataset]].
 *
 * @tparam T ML instance type
 * @see `DefaultParamsWritable`, the counterpart to this trait
 */
@DeveloperApi
trait DefaultParamsReadable[T] extends MLReadable[T] {

  override def read: MLReader[T] = new DefaultParamsReader[T]
}

/**
 * Default [[MLWriter]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types
 * with data (e.g., models with coefficients).
 *
 * @param instance object to save
 */
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {

  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}

private[ml] object DefaultParamsWriter {

  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - paramMap
   *  - (optionally, extra metadata)
   *
   * @param extraMetadata  Extra metadata to be saved at the same level as uid, paramMap, etc.
   * @param paramMap  If given, this is saved in the "paramMap" field.
   *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
   *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }

  /**
   * Helper for [[saveMetadata()]] which extracts the JSON to save.
   * This is useful for ensemble models which need to save metadata for many sub-models.
   *
   * @see [[saveMetadata()]] for details on what this includes.
   */
  def getMetadataToSave(
      instance: Params,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): String = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    compact(render(metadata))
  }
}
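
// For illustration, a metadata line of the shape produced by getMetadataToSave()
// and written to <path>/metadata (all values here are made up):
//
//   {"class":"org.apache.spark.ml.feature.Binarizer","timestamp":1480000000000,
//    "sparkVersion":"2.1.0","uid":"binarizer_abc123",
//    "paramMap":{"threshold":0.5,"inputCol":"raw","outputCol":"binarized"}}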

/**
 * Default [[MLReader]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types
 * with data (e.g., models with coefficients).
 *
 * @tparam T ML instance type
 * TODO: Consider adding a check for the correct class name.
 */
private[ml] class DefaultParamsReader[T] extends MLReader[T] {

  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    DefaultParamsReader.getAndSetParams(instance, metadata)
    instance.asInstanceOf[T]
  }
}

private[ml] object DefaultParamsReader {

  /**
   * All info from the metadata file.
   *
   * @param params  paramMap, as a `JValue`
   * @param metadata  All metadata, including the other fields
   * @param metadataJson  Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      metadata: JValue,
      metadataJson: String) {

    /**
     * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
     * This can be useful for getting a Param value before an instance of `Params`
     * is available.
     */
    def getParamValue(paramName: String): JValue = {
      implicit val format = DefaultFormats
      params match {
        case JObject(pairs) =>
          val values = pairs.filter { case (pName, jsonValue) =>
            pName == paramName
          }.map(_._2)
          assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
            s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
          values.head
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: $metadataJson.")
      }
    }
  }

  /**
   * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]].
   *
   * @param expectedClassName  If non-empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    parseMetadata(metadataStr, expectedClassName)
  }

  /**
   * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
   * This is a helper function for [[loadMetadata()]].
   *
   * @param metadataStr  JSON string of metadata
   * @param expectedClassName  If non-empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }
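
  // A minimal sketch of inspecting saved metadata (the path and param name are
  // hypothetical): getParamValue() reads a single Param's JSON value before any
  // Params instance has been constructed.
  //
  //   val metadata = loadMetadata("/tmp/lr", sc)
  //   implicit val format = DefaultFormats
  //   val maxIter = metadata.getParamValue("maxIter").extract[Int]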

  /**
   * Extract Params from metadata, and set them in the instance.
   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
   * TODO: Move to [[Metadata]] method
   */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /**
   * Load a `Params` instance from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}

/**
 * Default Meta-Algorithm read and write implementation.
 */
private[ml] object MetaAlgorithmReadWrite {

  /**
   * Examine the given estimator (which may be a compound estimator) and extract a mapping
   * from UIDs to corresponding `Params` instances.
   */
  def getUidMap(instance: Params): Map[String, Params] = {
    val uidList = getUidMapImpl(instance)
    val uidMap = uidList.toMap
    if (uidList.size != uidMap.size) {
      throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
        s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
    }
    uidMap
  }

  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
    val subStages: Array[Params] = instance match {
      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
      case ovr: OneVsRest => Array(ovr.getClassifier)
      case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
      case _: Params => Array.empty[Params]
    }
    val subStageMaps = subStages.flatMap(getUidMapImpl)
    List((instance.uid, instance)) ++ subStageMaps
  }
}
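
// A minimal sketch of getUidMap (the stages are hypothetical): a compound
// estimator such as a Pipeline is flattened into one (uid, Params) entry per
// stage, plus the pipeline's own uid; duplicate UIDs raise a RuntimeException.
//
//   val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
//   val uidMap: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(pipeline)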