/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import java.io.IOException

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils

/**
 * Trait for [[MLWriter]] and [[MLReader]].
 */
private[util] sealed trait BaseReadWrite {
  private var optionSparkSession: Option[SparkSession] = None

  /**
   * Sets the Spark SQLContext to use for saving/loading.
   */
  @Since("1.6.0")
  @deprecated("Use session instead. This method will be removed in 2.2.0.", "2.0.0")
  def context(sqlContext: SQLContext): this.type = {
    optionSparkSession = Option(sqlContext.sparkSession)
    this
  }

  /**
   * Sets the Spark Session to use for saving/loading.
   */
  @Since("2.0.0")
  def session(sparkSession: SparkSession): this.type = {
    optionSparkSession = Option(sparkSession)
    this
  }

  /**
   * Returns the user-specified Spark Session or the default.
   */
  protected final def sparkSession: SparkSession = {
    if (optionSparkSession.isEmpty) {
      optionSparkSession = Some(SparkSession.builder().getOrCreate())
    }
    optionSparkSession.get
  }

  /**
   * Returns the user-specified SQL context or the default.
   */
  protected final def sqlContext: SQLContext = sparkSession.sqlContext

  /** Returns the underlying `SparkContext`. */
  protected final def sc: SparkContext = sparkSession.sparkContext
}

/**
 * Abstract class for utility classes that can save ML instances.
 */
@Since("1.6.0")
abstract class MLWriter extends BaseReadWrite with Logging {

  protected var shouldOverwrite: Boolean = false

  /**
   * Saves the ML instance to the input path.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {
    val hadoopConf = sc.hadoopConfiguration
    val outputPath = new Path(path)
    val fs = outputPath.getFileSystem(hadoopConf)
    val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    if (fs.exists(qualifiedOutputPath)) {
      if (shouldOverwrite) {
        logInfo(s"Path $path already exists. It will be overwritten.")
        // TODO: Revert back to the original content if save is not successful.
        fs.delete(qualifiedOutputPath, true)
      } else {
        throw new IOException(
          s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
      }
    }
    saveImpl(path)
  }

  /**
   * [[save()]] handles overwriting and then calls this method.  Subclasses should override this
   * method to implement the actual saving of the instance.
   */
  @Since("1.6.0")
  protected def saveImpl(path: String): Unit

  /**
   * Overwrites if the output path already exists.
   */
  @Since("1.6.0")
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}
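
// A minimal usage sketch (assuming `spark` is an active SparkSession and `model` is any fitted
// MLWritable instance, e.g. a PipelineModel; the path is hypothetical). `overwrite()` lets the
// writer replace an existing path instead of throwing an IOException:
//
//   model.write
//     .session(spark)
//     .overwrite()
//     .save("/tmp/spark-model")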

/**
 * Trait for classes that provide [[MLWriter]].
 */
@Since("1.6.0")
trait MLWritable {

  /**
   * Returns an [[MLWriter]] instance for this ML instance.
   */
  @Since("1.6.0")
  def write: MLWriter

  /**
   * Saves this ML instance to the input path, a shortcut of `write.save(path)`.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = write.save(path)
}

/**
 * :: DeveloperApi ::
 *
 * Helper trait for making simple `Params` types writable.  If a `Params` class stores
 * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
 * a default implementation of saving instances of the class.
 * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
 * [[org.apache.spark.sql.Dataset]].
 *
 * @see `DefaultParamsReadable`, the counterpart to this trait
 */
@DeveloperApi
trait DefaultParamsWritable extends MLWritable { self: Params =>

  override def write: MLWriter = new DefaultParamsWriter(this)
}
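
// A hedged sketch (all identifiers below are hypothetical, not part of Spark): a Transformer whose
// entire state lives in Params can usually obtain `write`/`save` for free by mixing in
// DefaultParamsWritable. The matching companion object appears after DefaultParamsReadable below.
//
//   import org.apache.spark.ml.Transformer
//   import org.apache.spark.ml.param.{DoubleParam, ParamMap}
//   import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
//   import org.apache.spark.sql.{DataFrame, Dataset}
//   import org.apache.spark.sql.functions.col
//   import org.apache.spark.sql.types.StructType
//
//   class ColumnShifter(override val uid: String)
//     extends Transformer with DefaultParamsWritable {
//
//     def this() = this(Identifiable.randomUID("columnShifter"))
//
//     // The only state: a constant added to the "value" column.
//     final val shift = new DoubleParam(this, "shift", "constant added to column 'value'")
//     setDefault(shift -> 0.0)
//     def setShift(v: Double): this.type = set(shift, v)
//
//     override def transform(dataset: Dataset[_]): DataFrame =
//       dataset.withColumn("value", col("value") + $(shift))
//
//     override def transformSchema(schema: StructType): StructType = schema
//
//     override def copy(extra: ParamMap): ColumnShifter = defaultCopy(extra)
//   }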

/**
 * Abstract class for utility classes that can load ML instances.
 *
 * @tparam T ML instance type
 */
@Since("1.6.0")
abstract class MLReader[T] extends BaseReadWrite {

  /**
   * Loads the ML component from the input path.
   */
  @Since("1.6.0")
  def load(path: String): T

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
 * Trait for objects that provide [[MLReader]].
 *
 * @tparam T ML instance type
 */
@Since("1.6.0")
trait MLReadable[T] {

  /**
   * Returns an [[MLReader]] instance for this class.
   */
  @Since("1.6.0")
  def read: MLReader[T]

  /**
   * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
   *
   * @note Implementing classes should override this to be Java-friendly.
   */
  @Since("1.6.0")
  def load(path: String): T = read.load(path)
}
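
// A minimal loading sketch, mirroring the save sketch after MLWriter above; PipelineModel is one
// concrete MLReadable, and the path is hypothetical:
//
//   import org.apache.spark.ml.PipelineModel
//
//   val sameModel = PipelineModel.load("/tmp/spark-model")
//   // equivalent to: PipelineModel.read.load("/tmp/spark-model")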


/**
 * :: DeveloperApi ::
 *
 * Helper trait for making simple `Params` types readable.  If a `Params` class stores
 * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
 * a default implementation of reading saved instances of the class.
 * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
 * [[org.apache.spark.sql.Dataset]].
 *
 * @tparam T ML instance type
 * @see `DefaultParamsWritable`, the counterpart to this trait
 */
@DeveloperApi
trait DefaultParamsReadable[T] extends MLReadable[T] {

  override def read: MLReader[T] = new DefaultParamsReader[T]
}
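
// Continuing the hypothetical ColumnShifter sketch from above: its companion object only needs to
// extend DefaultParamsReadable to obtain `read`/`load`, since all of its state was stored as Params.
//
//   object ColumnShifter extends DefaultParamsReadable[ColumnShifter] {
//     override def load(path: String): ColumnShifter = super.load(path)
//   }
//
//   // Round trip (hypothetical path):
//   new ColumnShifter().setShift(1.0).write.overwrite().save("/tmp/column-shifter")
//   val restored = ColumnShifter.load("/tmp/column-shifter")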

/**
 * Default [[MLWriter]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 *
 * @param instance object to save
 */
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {

  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}

private[ml] object DefaultParamsWriter {

  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - paramMap
   *  - (optionally, extra metadata)
   *
   * @param extraMetadata  Extra metadata to be saved at same level as uid, paramMap, etc.
   * @param paramMap  If given, this is saved in the "paramMap" field.
   *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
   *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }

  /**
   * Helper for [[saveMetadata()]] which extracts the JSON to save.
   * This is useful for ensemble models which need to save metadata for many sub-models.
   *
   * @see [[saveMetadata()]] for details on what this includes.
   */
  def getMetadataToSave(
      instance: Params,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): String = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataJson: String = compact(render(metadata))
    metadataJson
  }
}
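
// Illustrative sketch of how a model-specific writer typically builds on saveMetadata: Params plus
// class/uid/version information go into path/metadata, while anything that is not a Param (here a
// hypothetical coefficient vector) is written separately under path/data. `MyModel` and `Data` are
// hypothetical names, not part of this file.
//
//   private class MyModelWriter(instance: MyModel) extends MLWriter {
//     private case class Data(coefficients: org.apache.spark.ml.linalg.Vector)
//
//     override protected def saveImpl(path: String): Unit = {
//       // Save metadata and Params.
//       DefaultParamsWriter.saveMetadata(instance, path, sc)
//       // Save the model data that cannot be expressed as Params.
//       val dataPath = new Path(path, "data").toString
//       sparkSession.createDataFrame(Seq(Data(instance.coefficients)))
//         .repartition(1).write.parquet(dataPath)
//     }
//   }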

/**
 * Default [[MLReader]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 *
 * @tparam T ML instance type
 * TODO: Consider adding check for correct class name.
 */
private[ml] class DefaultParamsReader[T] extends MLReader[T] {

  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    DefaultParamsReader.getAndSetParams(instance, metadata)
    instance.asInstanceOf[T]
  }
}

private[ml] object DefaultParamsReader {

  /**
   * All info from metadata file.
   *
   * @param params  paramMap, as a `JValue`
   * @param metadata  All metadata, including the other fields
   * @param metadataJson  Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      metadata: JValue,
      metadataJson: String) {

    /**
     * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
     * This can be useful for getting a Param value before an instance of `Params`
     * is available.
     */
    def getParamValue(paramName: String): JValue = {
      implicit val format = DefaultFormats
      params match {
        case JObject(pairs) =>
          val values = pairs.filter { case (pName, jsonValue) =>
            pName == paramName
          }.map(_._2)
          assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
            s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
          values.head
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: $metadataJson.")
      }
    }
  }

  /**
   * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
   *
   * @param expectedClassName  If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    parseMetadata(metadataStr, expectedClassName)
  }

  /**
   * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
   * This is a helper function for [[loadMetadata()]].
   *
   * @param metadataStr  JSON string of metadata
   * @param expectedClassName  If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }

  /**
   * Extract Params from metadata, and set them in the instance.
   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
   * TODO: Move to [[Metadata]] method
   */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /**
   * Load a `Params` instance from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}
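
// Illustrative sketch of the matching reader: validate the class name while loading metadata, read
// the data written by the writer sketch above, construct the model, then restore its Params.
// `MyModel` and the "coefficients" column are hypothetical, mirroring the writer sketch.
//
//   private class MyModelReader extends MLReader[MyModel] {
//     private val className = classOf[MyModel].getName
//
//     override def load(path: String): MyModel = {
//       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
//       val dataPath = new Path(path, "data").toString
//       val coefficients = sparkSession.read.parquet(dataPath)
//         .select("coefficients").head().getAs[org.apache.spark.ml.linalg.Vector](0)
//       val model = new MyModel(metadata.uid, coefficients)
//       DefaultParamsReader.getAndSetParams(model, metadata)
//       model
//     }
//   }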

/**
 * Default Meta-Algorithm read and write implementation.
 */
private[ml] object MetaAlgorithmReadWrite {
  /**
   * Examine the given estimator (which may be a compound estimator) and extract a mapping
   * from UIDs to corresponding `Params` instances.
   */
  def getUidMap(instance: Params): Map[String, Params] = {
    val uidList = getUidMapImpl(instance)
    val uidMap = uidList.toMap
    if (uidList.size != uidMap.size) {
      throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
        s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
    }
    uidMap
  }

  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
    val subStages: Array[Params] = instance match {
      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
      case ovr: OneVsRest => Array(ovr.getClassifier)
      case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
      case _: Params => Array.empty[Params]
    }
    val subStageMaps = subStages.flatMap(getUidMapImpl)
    List((instance.uid, instance)) ++ subStageMaps
  }
}
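
// Illustrative only: writers for compound estimators can flatten the stage tree with getUidMap to
// verify that UIDs are unique before persisting. Assuming `pipeline` is a Pipeline whose stages
// have already been set:
//
//   val uidToStage: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(pipeline)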