/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkContext
import org.apache.spark.ml.param.{ParamPair, Params}
import org.json4s.jackson.JsonMethods._
import org.json4s.{JArray, JBool, JDouble, JField, JInt, JNothing, JObject, JString, JValue}

import JsonDSLXGBoost._

// Copied from Apache Spark's DefaultParamsWriter (org.apache.spark.ml.util), but built on the
// JsonDSLXGBoost implicits instead of org.json4s.JsonDSL.
private[spark] object DefaultXGBoostParamsWriter {

  /**
   * Writes the metadata of `instance` as JSON under `path/metadata`. The JSON is saved through a
   * one-partition RDD with `saveAsTextFile`.
   *
   * The JSON object contains, in this order:
   *  - `class`: runtime class name of `instance`
   *  - `timestamp`: `System.currentTimeMillis()` at the time of the call
   *  - `sparkVersion`: `sc.version`
   *  - `uid`: `instance.uid`
   *  - `paramMap`: the encoded params (see `paramMap` below)
   *  - every field of `extraMetadata`, if given, at the same level
   *
   * @param instance      the Params whose metadata is written
   * @param path          base directory; output goes to `path/metadata`
   * @param sc            SparkContext used for the write and as the Spark version source
   * @param extraMetadata optional extra fields stored next to `uid`, `paramMap`, etc.
   * @param paramMap      if given, stored verbatim as the "paramMap" field. Otherwise every
   *                      non-null [[org.apache.spark.ml.param.Param]] is encoded with
   *                      [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val target = new Path(path, "metadata")
    val targetDir = target.toString
    val json = getMetadataToSave(instance, sc, extraMetadata = extraMetadata, paramMap = paramMap)
    sc.parallelize(json :: Nil, 1).saveAsTextFile(targetDir)
  }

  /**
   * Builds the compact JSON string that [[saveMetadata()]] writes, without writing anything.
   * This is useful for ensemble models that must embed metadata for many sub-models.
   *
   * @see [[saveMetadata()]] for the fields produced and the meaning of each parameter.
   * @return the metadata as a compact (single-line) JSON string
   */
  def getMetadataToSave(
      instance: Params,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): String = {
    val instanceUid = instance.uid
    val className = instance.getClass.getName
    val allPairs = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]

    // Encode the params ourselves only when the caller did not supply a ready-made JSON value.
    val paramsJson = paramMap.getOrElse {
      val encoded = allPairs
        .filter(pair => pair.param != null)
        .map(pair => pair.param.name -> parse(pair.param.jsonEncode(pair.value)))
      render(encoded.toList)
    }

    val core = ("class" -> className) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> instanceUid) ~
      ("paramMap" -> paramsJson)

    // Extra metadata fields, when present, are appended after the core fields.
    val full = extraMetadata.fold(core)(extra => core ~ extra)
    compact(render(full))
  }
}

// Fix json4s bin-incompatible issue.
// This is a copy of org.json4s.JsonDSL from json4s 3.6.6. Code in this package imports these
// implicits (via `import JsonDSLXGBoost._`) instead of json4s's own JsonDSL. Numeric handling
// matches json4s's DoubleMode: Double, Float and BigDecimal are encoded as JDouble.
object JsonDSLXGBoost {
  // Defining implicit conversion methods is an opt-in Scala language feature. Enabling it here
  // silences the feature warning (fatal under -Xfatal-warnings). The import is not a member of
  // this object, so `import JsonDSLXGBoost._` does not re-export it.
  import scala.language.implicitConversions

  /** Encodes any iterable as a JArray, converting each element with `ev`, in iteration order. */
  implicit def seq2jvalue[A](s: Iterable[A])(implicit ev: A => JValue): JArray =
    JArray(s.toList.map(ev))

  /** Encodes a string-keyed map as a JObject; field order follows `m.toList`. */
  implicit def map2jvalue[A](m: Map[String, A])(implicit ev: A => JValue): JObject =
    JObject(m.toList.map { case (k, v) => JField(k, ev(v)) })

  /** `Some(x)` becomes `ev(x)`; `None` becomes JNothing, so the field is dropped when rendered. */
  implicit def option2jvalue[A](opt: Option[A])(implicit ev: A => JValue): JValue = opt match {
    case Some(x) => ev(x)
    case None => JNothing
  }

  // Integral primitives are widened to BigInt and wrapped in JInt.
  implicit def short2jvalue(x: Short): JValue = JInt(x)
  implicit def byte2jvalue(x: Byte): JValue = JInt(x)
  implicit def char2jvalue(x: Char): JValue = JInt(x)
  implicit def int2jvalue(x: Int): JValue = JInt(x)
  implicit def long2jvalue(x: Long): JValue = JInt(x)
  implicit def bigint2jvalue(x: BigInt): JValue = JInt(x)
  // Fractional values use double precision ("DoubleMode"); BigDecimal may lose precision here.
  implicit def double2jvalue(x: Double): JValue = JDouble(x)
  implicit def float2jvalue(x: Float): JValue = JDouble(x.toDouble)
  implicit def bigdecimal2jvalue(x: BigDecimal): JValue = JDouble(x.doubleValue)
  implicit def boolean2jvalue(x: Boolean): JValue = JBool(x)
  implicit def string2jvalue(x: String): JValue = JString(x)

  /** Encodes a Symbol by its name. */
  implicit def symbol2jvalue(x: Symbol): JString = JString(x.name)

  /** Encodes a single `(name, value)` pair as a one-field JObject. */
  implicit def pair2jvalue[A](t: (String, A))(implicit ev: A => JValue): JObject =
    JObject(List(JField(t._1, ev(t._2))))

  /** Wraps an already-built field list as a JObject. */
  implicit def list2jvalue(l: List[JField]): JObject = JObject(l)

  /** Enables `jobject ~ field` / `jobject ~ jobject` chaining. */
  implicit def jobject2assoc(o: JObject): JsonListAssoc = new JsonListAssoc(o.obj)

  /** Enables `("a" -> x) ~ ("b" -> y)` chaining starting from a single pair. */
  implicit def pair2Assoc[A](t: (String, A))(implicit ev: A => JValue): JsonAssoc[A] =
    new JsonAssoc(t)
}

/**
 * Left-hand side of a DSL expression that starts from a single `(name, value)` pair.
 * Every `~` produces a new JObject with the left pair first.
 */
final class JsonAssoc[A](private val left: (String, A)) extends AnyVal {
  /** Combines two pairs into a two-field JObject, left pair first. */
  def ~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject = {
    val l: JValue = ev1(left._2)
    val r: JValue = ev2(right._2)
    JObject(JField(left._1, l) :: JField(right._1, r) :: Nil)
  }

  /** Prepends the left pair to the fields of `right`. */
  def ~(right: JObject)(implicit ev: A => JValue): JObject = {
    val l: JValue = ev(left._2)
    JObject(JField(left._1, l) :: right.obj)
  }

  // Aliases for `~`. The `this.` prefix is required: a bare `~(right)` would parse as the
  // prefix operator application `right.unary_~`.
  def ~~[B](right: (String, B))(implicit ev1: A => JValue, ev2: B => JValue): JObject =
    this.~(right)
  def ~~(right: JObject)(implicit ev: A => JValue): JObject = this.~(right)
}

/**
 * Left-hand side of a DSL expression that starts from an existing JObject's field list.
 * Each `~` appends with `:::`, which copies `left`; that is fine for small literal objects.
 */
final class JsonListAssoc(private val left: List[JField]) extends AnyVal {
  /** Appends a single already-encoded field. */
  def ~(right: (String, JValue)): JObject = JObject(left ::: List(JField(right._1, right._2)))
  /** Appends all fields of `right`, keeping their order. */
  def ~(right: JObject): JObject = JObject(left ::: right.obj)
  // Aliases for `~`; `this.` avoids the prefix `unary_~` parse of a bare `~(right)`.
  def ~~(right: (String, JValue)): JObject = this.~(right)
  def ~~(right: JObject): JObject = this.~(right)
}