1/*
2 Copyright (c) 2014 by Contributors
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17package ml.dmlc.xgboost4j.scala.spark.params
18
19import org.apache.hadoop.fs.Path
20
21import org.apache.spark.SparkContext
22import org.apache.spark.ml.param.{ParamPair, Params}
23import org.json4s.jackson.JsonMethods._
24import org.json4s.{JArray, JBool, JDouble, JField, JInt, JNothing, JObject, JString, JValue}
25
26import JsonDSLXGBoost._
27
// Adapted (copied) from Apache Spark's DefaultParamsWriter.
private[spark] object DefaultXGBoostParamsWriter {

  /**
   * Writes the metadata JSON of `instance` as a one-partition text file at
   * `path` + "/metadata".
   *
   * The JSON document contains:
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - paramMap
   *  - the fields of `extraMetadata`, when it is supplied
   *
   * @param instance      The [[org.apache.spark.ml.param.Params]] whose metadata is written.
   * @param path          Parent location; the output goes to its "metadata" child.
   * @param sc            Spark context used to write the file and to read the Spark version.
   * @param extraMetadata Extra fields stored at the same level as uid, paramMap, etc.
   * @param paramMap      When defined, stored as-is in the "paramMap" field. Otherwise the
   *                      params returned by `instance.extractParamMap()` are encoded using
   *                      [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
    instance: Params,
    path: String,
    sc: SparkContext,
    extraMetadata: Option[JObject] = None,
    paramMap: Option[JValue] = None): Unit = {

    val target = new Path(path, "metadata").toString
    val json = getMetadataToSave(instance, sc, extraMetadata, paramMap)
    // A single partition keeps the metadata in one output part file.
    sc.parallelize(Seq(json), 1).saveAsTextFile(target)
  }

  /**
   * Builds the compact JSON string that [[saveMetadata()]] writes, without touching storage.
   * Kept separate so that callers which must persist metadata for several sub-models
   * (e.g. ensembles) can obtain the JSON directly.
   *
   * @see [[saveMetadata()]] for the fields included and the meaning of the parameters.
   * @return the metadata rendered as a single-line JSON string
   */
  def getMetadataToSave(
    instance: Params,
    sc: SparkContext,
    extraMetadata: Option[JObject] = None,
    paramMap: Option[JValue] = None): String = {
    val instanceUid = instance.uid
    val className = instance.getClass.getName
    val paramPairs = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]

    // An explicitly supplied paramMap wins; otherwise encode each non-null param.
    val encodedParams: JValue = paramMap match {
      case Some(supplied) =>
        supplied
      case None =>
        val fields = paramPairs
          .filter { case ParamPair(param, _) => param != null }
          .map { case ParamPair(param, value) => param.name -> parse(param.jsonEncode(value)) }
          .toList
        render(fields)
    }

    val coreFields = ("class" -> className) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> instanceUid) ~
      ("paramMap" -> encodedParams)

    // Extra metadata fields, if any, are appended after the core fields.
    val fullMetadata = extraMetadata.fold(coreFields)(extra => coreFields ~ extra)
    compact(render(fullMetadata))
  }
}
92
// Works around a json4s binary-incompatibility issue.
// This is adapted from org.json4s.JsonDSL of json4s 3.6.6.
/**
 * Implicit conversions from plain Scala values to json4s AST nodes, plus the
 * conversions that enable the `~` operator for assembling JSON objects.
 */
object JsonDSLXGBoost {

  /** Converts each element with `ev`, keeping iteration order, and wraps them in a JArray. */
  implicit def seq2jvalue[A](s: Iterable[A])(implicit ev: A => JValue): JArray = {
    val elements = s.toList.map(item => ev(item))
    JArray(elements)
  }

  /** Turns every map entry into a field whose value is converted with `ev`. */
  implicit def map2jvalue[A](m: Map[String, A])(implicit ev: A => JValue): JObject = {
    val fields = m.toList.map { case (key, value) => JField(key, ev(value)) }
    JObject(fields)
  }

  /** A defined option is converted with `ev`; an empty one becomes JNothing. */
  implicit def option2jvalue[A](opt: Option[A])(implicit ev: A => JValue): JValue =
    opt.fold[JValue](JNothing)(ev)

  // Integral types all map to JInt.
  implicit def short2jvalue(x: Short): JValue = JInt(x)
  implicit def byte2jvalue(x: Byte): JValue = JInt(x)
  implicit def char2jvalue(x: Char): JValue = JInt(x)
  implicit def int2jvalue(x: Int): JValue = JInt(x)
  implicit def long2jvalue(x: Long): JValue = JInt(x)
  implicit def bigint2jvalue(x: BigInt): JValue = JInt(x)

  // Fractional types all map to JDouble (Float and BigDecimal go through Double).
  implicit def double2jvalue(x: Double): JValue = JDouble(x)
  implicit def float2jvalue(x: Float): JValue = JDouble(x.toDouble)
  implicit def bigdecimal2jvalue(x: BigDecimal): JValue = JDouble(x.doubleValue)

  implicit def boolean2jvalue(x: Boolean): JValue = JBool(x)
  implicit def string2jvalue(x: String): JValue = JString(x)

  /** A symbol is represented by its name as a JString. */
  implicit def symbol2jvalue(x: Symbol): JString = JString(x.name)

  /** A (name, value) pair becomes a JObject holding exactly that one field. */
  implicit def pair2jvalue[A](t: (String, A))(implicit ev: A => JValue): JObject =
    JObject(JField(t._1, ev(t._2)) :: Nil)

  /** A list of fields becomes a JObject with those fields. */
  implicit def list2jvalue(l: List[JField]): JObject = JObject(l)

  /** Exposes the `~` / `~~` operators on an existing JObject. */
  implicit def jobject2assoc(o: JObject): JsonListAssoc = new JsonListAssoc(o.obj)

  /** Exposes the `~` / `~~` operators on a (name, value) pair. */
  implicit def pair2Assoc[A](t: (String, A))(implicit ev: A => JValue): JsonAssoc[A] =
    new JsonAssoc[A](t)
}
128
/**
 * Wrapper around a (name, value) pair that supplies the `~` operator used to
 * combine it with another pair or with an existing JObject.
 */
final class JsonAssoc[A](private val left: (String, A)) extends AnyVal {

  /** Builds a two-field JObject: the wrapped pair first, then `right`. */
  def ~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject = {
    val first = JField(left._1, ev1(left._2))
    val second = JField(right._1, ev2(right._2))
    JObject(List(first, second))
  }

  /** Builds a JObject with the wrapped pair placed in front of the fields of `right`. */
  def ~(right: JObject)(implicit ev: A => JValue): JObject = {
    val head = JField(left._1, ev(left._2))
    JObject(head :: right.obj)
  }

  /** Alias of `~` for a (name, value) pair. */
  def ~~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject =
    this ~ right

  /** Alias of `~` for a JObject. */
  def ~~(right: JObject)(implicit ev: A => JValue): JObject = this ~ right
}
144
/**
 * Wrapper around a list of fields that supplies the `~` operator used to
 * append a further field, or the fields of another JObject.
 */
final class JsonListAssoc(private val left: List[JField]) extends AnyVal {

  /** Returns a JObject with `right` appended after the wrapped fields. */
  def ~(right: (String, JValue)): JObject = {
    val appended = JField(right._1, right._2) :: Nil
    JObject(left ::: appended)
  }

  /** Returns a JObject with the fields of `right` appended after the wrapped fields. */
  def ~(right: JObject): JObject = {
    val combined = left ::: right.obj
    JObject(combined)
  }

  /** Alias of `~` for a (name, value) pair. */
  def ~~(right: (String, JValue)): JObject = this ~ right

  /** Alias of `~` for a JObject. */
  def ~~(right: JObject): JObject = this ~ right
}
151