/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.util.Arrays

import ml.dmlc.xgboost4j.scala.DMatrix

import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
import org.scalatest.FunSuite

class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {

  test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
    val eval = new EvalError()
    val trainingDF = buildDataFrame(Classification.train)
    val testDM = new DMatrix(Classification.test.iterator)

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgbc = new XGBoostClassifier(paramMap)
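    // Round-trip the estimator through save/load and verify that the XGBoost params survive.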
    val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
    xgbc.write.overwrite().save(xgbcPath)
    val xgbc2 = XGBoostClassifier.load(xgbcPath)
    val paramMap2 = xgbc2.MLlib2XGBoostParams
    paramMap.foreach {
      case (k, v) => assert(v.toString == paramMap2(k).toString)
    }

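    // Train, persist and reload the model; the reloaded booster should be byte-for-byte identical
    // and produce the same evaluation result.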
    val model = xgbc.fit(trainingDF)
    val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults < 0.1)
    val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
    model.write.overwrite().save(xgbcModelPath)
    val model2 = XGBoostClassificationModel.load(xgbcModelPath)
    assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

    assert(model.getEta === model2.getEta)
    assert(model.getNumRound === model2.getNumRound)
    assert(model.getRawPredictionCol === model2.getRawPredictionCol)
    val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults === evalResults2)
  }

  test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
    val eval = new EvalError()
    val trainingDF = buildDataFrame(Regression.train)
    val testDM = new DMatrix(Regression.test.iterator)

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgbr = new XGBoostRegressor(paramMap)
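    // Round-trip the regressor through save/load and verify that the XGBoost params survive.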
    val xgbrPath = new File(tempDir.toFile, "xgbr").getPath
    xgbr.write.overwrite().save(xgbrPath)
    val xgbr2 = XGBoostRegressor.load(xgbrPath)
    val paramMap2 = xgbr2.MLlib2XGBoostParams
    paramMap.foreach {
      case (k, v) => assert(v.toString == paramMap2(k).toString)
    }

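    // Train, persist and reload the regression model; compare boosters, params and eval results.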
    val model = xgbr.fit(trainingDF)
    val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults < 0.1)
    val xgbrModelPath = new File(tempDir.toFile, "xgbrModel").getPath
    model.write.overwrite().save(xgbrModelPath)
    val model2 = XGBoostRegressionModel.load(xgbrModelPath)
    assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

    assert(model.getEta === model2.getEta)
    assert(model.getNumRound === model2.getNumRound)
    assert(model.getPredictionCol === model2.getPredictionCol)
    val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults === evalResults2)
  }

  test("test persistence of MLlib pipeline with XGBoostClassificationModel") {

    val r = new Random(0)
    // maybe move to shared context, but requires session to import implicits
    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
      toDF("feature", "label")

    val assembler = new VectorAssembler()
      .setInputCols(df.columns.filter(!_.contains("label")))
      .setOutputCol("features")

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgb = new XGBoostClassifier(paramMap)

    // Construct MLlib pipeline, save and load
    val pipeline = new Pipeline().setStages(Array(assembler, xgb))
    val pipePath = new File(tempDir.toFile, "pipeline").getPath
    pipeline.write.overwrite().save(pipePath)
    val pipeline2 = Pipeline.read.load(pipePath)
    val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
    val paramMap2 = xgb2.MLlib2XGBoostParams
    paramMap.foreach {
      case (k, v) => assert(v.toString == paramMap2(k).toString)
    }

    // Model training, save and load
    val pipeModel = pipeline.fit(df)
    val pipeModelPath = new File(tempDir.toFile, "pipelineModel").getPath
    pipeModel.write.overwrite().save(pipeModelPath)
    val pipeModel2 = PipelineModel.load(pipeModelPath)

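    // The XGBoost stage of the reloaded pipeline model should match the original stage.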
    val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
    val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]

    assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))

    assert(xgbModel.getEta === xgbModel2.getEta)
    assert(xgbModel.getNumRound === xgbModel2.getNumRound)
    assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
  }

  test("cross-version model loading (0.82)") {
    val modelPath = getClass.getResource("/model/0.82/model").getPath
    val model = XGBoostClassificationModel.read.load(modelPath)
    val r = new Random(0)
    var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
      toDF("feature", "label")
    // The 0.82 model was trained with 251 features, and transform will throw an exception
    // if the feature size of the data does not equal 251.
    for (x <- 1 to 250) {
      df = df.withColumn(s"feature_${x}", lit(1))
    }
    val assembler = new VectorAssembler()
      .setInputCols(df.columns.filter(!_.contains("label")))
      .setOutputCol("features")
    df = assembler.transform(df)
    for (x <- 1 to 250) {
      df = df.drop(s"feature_${x}")
    }
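    // A successful transform shows the 0.82 model can be loaded and scored by the current version.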
    model.transform(df).show()
  }
}