/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.util.Arrays

import ml.dmlc.xgboost4j.scala.DMatrix

import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
import org.scalatest.FunSuite

/**
 * Save/load round-trip tests for XGBoostClassifier, XGBoostRegressor, their fitted models,
 * an MLlib pipeline containing an XGBoost stage, and loading a model saved by version 0.82.
 */
class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {

  /**
   * Asserts that every entry of `expected` is present in `restored` with the same value.
   *
   * Values are compared by their `toString` forms (as this suite has always done), presumably
   * because a restored value need not keep the runtime type it was supplied with.
   *
   * @param expected the parameter map the estimator was constructed with
   * @param restored the parameter map reported by the reloaded estimator
   */
  private def assertParamsRestored(
      expected: Map[String, Any],
      restored: Map[String, Any]): Unit = {
    expected.foreach { case (key, value) =>
      assert(restored.contains(key), s"parameter '$key' missing after save/load")
      val restoredValue = restored(key)
      assert(value.toString == restoredValue.toString,
        s"parameter '$key' changed across save/load: '$value' vs '$restoredValue'")
    }
  }

  test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
    val eval = new EvalError()
    val trainingDF = buildDataFrame(Classification.train)
    val testDM = new DMatrix(Classification.test.iterator)

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgbc = new XGBoostClassifier(paramMap)
    val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
    xgbc.write.overwrite().save(xgbcPath)
    val xgbc2 = XGBoostClassifier.load(xgbcPath)
    assertParamsRestored(paramMap, xgbc2.MLlib2XGBoostParams)

    val model = xgbc.fit(trainingDF)
    val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults < 0.1)
    val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
    model.write.overwrite().save(xgbcModelPath)
    val model2 = XGBoostClassificationModel.load(xgbcModelPath)
    // The serialized booster must be byte-for-byte identical after the round trip.
    assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

    assert(model.getEta === model2.getEta)
    assert(model.getNumRound === model2.getNumRound)
    assert(model.getRawPredictionCol === model2.getRawPredictionCol)
    val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults === evalResults2)
  }

  test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
    val eval = new EvalError()
    val trainingDF = buildDataFrame(Regression.train)
    val testDM = new DMatrix(Regression.test.iterator)

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:squarederror", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgbr = new XGBoostRegressor(paramMap)
    val xgbrPath = new File(tempDir.toFile, "xgbr").getPath
    xgbr.write.overwrite().save(xgbrPath)
    val xgbr2 = XGBoostRegressor.load(xgbrPath)
    assertParamsRestored(paramMap, xgbr2.MLlib2XGBoostParams)

    val model = xgbr.fit(trainingDF)
    val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults < 0.1)
    val xgbrModelPath = new File(tempDir.toFile, "xgbrModel").getPath
    model.write.overwrite().save(xgbrModelPath)
    val model2 = XGBoostRegressionModel.load(xgbrModelPath)
    // The serialized booster must be byte-for-byte identical after the round trip.
    assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))

    assert(model.getEta === model2.getEta)
    assert(model.getNumRound === model2.getNumRound)
    assert(model.getPredictionCol === model2.getPredictionCol)
    val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
    assert(evalResults === evalResults2)
  }

  test("test persistence of MLlib pipeline with XGBoostClassificationModel") {

    val r = new Random(0)
    // maybe move to shared context, but requires session to import implicits
    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i)))
      .toDF("feature", "label")

    val assembler = new VectorAssembler()
      .setInputCols(df.columns.filter(!_.contains("label")))
      .setOutputCol("features")

    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
    val xgb = new XGBoostClassifier(paramMap)

    // Construct MLlib pipeline, save and load
    val pipeline = new Pipeline().setStages(Array(assembler, xgb))
    val pipePath = new File(tempDir.toFile, "pipeline").getPath
    pipeline.write.overwrite().save(pipePath)
    val pipeline2 = Pipeline.read.load(pipePath)
    val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
    assertParamsRestored(paramMap, xgb2.MLlib2XGBoostParams)

    // Model training, save and load
    val pipeModel = pipeline.fit(df)
    val pipeModelPath = new File(tempDir.toFile, "pipelineModel").getPath
    pipeModel.write.overwrite().save(pipeModelPath)
    val pipeModel2 = PipelineModel.load(pipeModelPath)

    val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
    val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]

    assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))

    assert(xgbModel.getEta === xgbModel2.getEta)
    assert(xgbModel.getNumRound === xgbModel2.getNumRound)
    assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
  }

  test("cross-version model loading (0.82)") {
    val modelPath = getClass.getResource("/model/0.82/model").getPath
    val model = XGBoostClassificationModel.read.load(modelPath)
    val r = new Random(0)
    val rawDF = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i)))
      .toDF("feature", "label")
    // The 0.82 model was trained with 251 features, and transform throws if the feature vector
    // has a different size, so pad the single "feature" column with 250 constant columns.
    val paddingCols = (1 to 250).map(x => s"feature_$x")
    // Add all padding columns in one projection: calling withColumn in a loop builds a deeply
    // nested plan, which Spark warns can be slow or even overflow the stack.
    val paddedDF = rawDF.select((col("*") +: paddingCols.map(c => lit(1).as(c))): _*)
    val assembler = new VectorAssembler()
      .setInputCols(paddedDF.columns.filter(!_.contains("label")))
      .setOutputCol("features")
    val df = assembler.transform(paddedDF).drop(paddingCols: _*)

    // Actually check the old model scores every row instead of only printing a preview.
    val predictions = model.transform(df)
    assert(predictions.count() === df.count())
    assert(predictions.columns.contains(model.getPredictionCol))
  }
}