/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.util._
import org.apache.spark.sql._

/**
 * R-facing wrapper around a fitted generalized linear regression pipeline.
 *
 * The wrapped pipeline holds the fitted RFormula model at stage 0 and a
 * [[GeneralizedLinearRegressionModel]] at stage 1, which is the layout built by
 * `GeneralizedLinearRegressionWrapper.fit`. The remaining fields are summary values captured at
 * fit time and laid out the way the R side reads them.
 *
 * @param pipeline fitted pipeline: RFormula model followed by the GLM
 * @param rFeatures coefficient names, with "(Intercept)" first when an intercept was fitted
 * @param rCoefficients flattened coefficient table. It starts with the estimates. Only when the
 *                      summary used the normal solver, it continues with the standard errors,
 *                      t-values and p-values. Every section is ordered like `rFeatures`.
 * @param rDispersion dispersion reported by the training summary
 * @param rNullDeviance null deviance reported by the training summary
 * @param rDeviance deviance reported by the training summary
 * @param rResidualDegreeOfFreedomNull residual degrees of freedom of the null model
 * @param rResidualDegreeOfFreedom residual degrees of freedom of the fitted model
 * @param rAic AIC reported by the training summary
 * @param rNumIterations iteration count reported by the training summary
 * @param isLoaded true when this instance was restored by the reader rather than freshly fitted
 */
private[r] class GeneralizedLinearRegressionWrapper private (
    val pipeline: PipelineModel,
    val rFeatures: Array[String],
    val rCoefficients: Array[Double],
    val rDispersion: Double,
    val rNullDeviance: Double,
    val rDeviance: Double,
    val rResidualDegreeOfFreedomNull: Long,
    val rResidualDegreeOfFreedom: Long,
    val rAic: Double,
    val rNumIterations: Int,
    val isLoaded: Boolean = false) extends MLWritable {

  // Stage 1 of the pipeline is the fitted GLM (stage 0 is the RFormula model).
  private val glm: GeneralizedLinearRegressionModel =
    pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]

  // NOTE(review): the writer below persists only the pipeline and rMetadata, not the training
  // summary. The two summary-backed accessors are therefore presumably unusable when
  // isLoaded == true. Confirm this against GeneralizedLinearRegressionModel.summary.

  /** Residuals from the training summary, using the summary's default residual type. */
  lazy val rDevianceResiduals: DataFrame = glm.summary.residuals()

  /** Name of the error-distribution family the GLM was configured with. */
  lazy val rFamily: String = glm.getFamily

  /** Residuals from the training summary of the requested `residualsType`. */
  def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType)

  /** Runs the full pipeline on `dataset` and drops the intermediate features column. */
  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset).drop(glm.getFeaturesCol)
  }

  override def write: MLWriter =
    new GeneralizedLinearRegressionWrapper.GeneralizedLinearRegressionWrapperWriter(this)
}

private[r] object GeneralizedLinearRegressionWrapper
  extends MLReadable[GeneralizedLinearRegressionWrapper] {

  /**
   * Fits an RFormula + GLM pipeline on `data` and captures the R-facing summary values.
   *
   * @param formula R model formula parsed by [[RFormula]]
   * @param data training data
   * @param family GLM family name passed to `setFamily`
   * @param link link function name passed to `setLink`
   * @param tol convergence tolerance passed to `setTol`
   * @param maxIter maximum iterations passed to `setMaxIter`
   * @param weightCol instance-weight column passed to `setWeightCol`
   * @param regParam regularization parameter passed to `setRegParam`
   * @return a wrapper around the fitted pipeline and its summary values
   */
  def fit(
      formula: String,
      data: DataFrame,
      family: String,
      link: String,
      tol: Double,
      maxIter: Int,
      weightCol: String,
      regParam: Double): GeneralizedLinearRegressionWrapper = {
    val rFormula = new RFormula().setFormula(formula)
    checkDataColumns(rFormula, data)
    val rFormulaModel = rFormula.fit(data)
    // get feature names from the output schema's ML attributes
    val schema = rFormulaModel.transform(data).schema
    val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
      .attributes.get
    val features = featureAttrs.map(_.name.get)
    // assemble and fit the pipeline
    val glr = new GeneralizedLinearRegression()
      .setFamily(family)
      .setLink(link)
      .setFitIntercept(rFormula.hasIntercept)
      .setTol(tol)
      .setMaxIter(maxIter)
      .setWeightCol(weightCol)
      .setRegParam(regParam)
      .setFeaturesCol(rFormula.getFeaturesCol)
    val pipeline = new Pipeline()
      .setStages(Array(rFormulaModel, glr))
      .fit(data)

    val glm: GeneralizedLinearRegressionModel =
      pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
    val summary = glm.summary

    val rFeatures: Array[String] = if (glm.getFitIntercept) {
      Array("(Intercept)") ++ features
    } else {
      features
    }

    // The summary's per-coefficient statistics carry the intercept entry last, while R expects
    // it first (matching rFeatures). Move it to the front when an intercept was fitted.
    def interceptFirst(stats: Array[Double]): Array[Double] =
      if (glm.getFitIntercept) stats.last +: stats.dropRight(1) else stats

    val estimates: Array[Double] = if (glm.getFitIntercept) {
      glm.intercept +: glm.coefficients.toArray
    } else {
      glm.coefficients.toArray
    }

    // Standard errors, t-values and p-values are only appended when the summary reports the
    // normal solver; otherwise R receives the estimates alone.
    val rCoefficients: Array[Double] = if (summary.isNormalSolver) {
      estimates ++
        interceptFirst(summary.coefficientStandardErrors) ++
        interceptFirst(summary.tValues) ++
        interceptFirst(summary.pValues)
    } else {
      estimates
    }

    val rDispersion: Double = summary.dispersion
    val rNullDeviance: Double = summary.nullDeviance
    val rDeviance: Double = summary.deviance
    val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull
    val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom
    val rAic: Double = summary.aic
    val rNumIterations: Int = summary.numIterations

    new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
      rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
      rAic, rNumIterations)
  }

  override def read: MLReader[GeneralizedLinearRegressionWrapper] =
    new GeneralizedLinearRegressionWrapperReader

  override def load(path: String): GeneralizedLinearRegressionWrapper = super.load(path)

  /**
   * Saves a wrapper as two sub-paths of `path`:
   *   - `rMetadata`: a single-partition text file holding one compact JSON object with the
   *     R-facing fields.
   *   - `pipeline`: the fitted [[PipelineModel]].
   */
  class GeneralizedLinearRegressionWrapperWriter(instance: GeneralizedLinearRegressionWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = ("class" -> instance.getClass.getName) ~
        ("rFeatures" -> instance.rFeatures.toSeq) ~
        ("rCoefficients" -> instance.rCoefficients.toSeq) ~
        ("rDispersion" -> instance.rDispersion) ~
        ("rNullDeviance" -> instance.rNullDeviance) ~
        ("rDeviance" -> instance.rDeviance) ~
        ("rResidualDegreeOfFreedomNull" -> instance.rResidualDegreeOfFreedomNull) ~
        ("rResidualDegreeOfFreedom" -> instance.rResidualDegreeOfFreedom) ~
        ("rAic" -> instance.rAic) ~
        ("rNumIterations" -> instance.rNumIterations)
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }

  /**
   * Restores a wrapper written by [[GeneralizedLinearRegressionWrapperWriter]]. It reads the
   * first line of `rMetadata` and reloads the pipeline. The result is marked `isLoaded = true`.
   */
  class GeneralizedLinearRegressionWrapperReader
    extends MLReader[GeneralizedLinearRegressionWrapper] {

    override def load(path: String): GeneralizedLinearRegressionWrapper = {
      // Explicit type: Scala 2.13 warns on untyped implicit definitions and Scala 3 rejects them.
      implicit val format: Formats = DefaultFormats
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
      val rMetadata = parse(rMetadataStr)
      val rFeatures = (rMetadata \ "rFeatures").extract[Array[String]]
      val rCoefficients = (rMetadata \ "rCoefficients").extract[Array[Double]]
      val rDispersion = (rMetadata \ "rDispersion").extract[Double]
      val rNullDeviance = (rMetadata \ "rNullDeviance").extract[Double]
      val rDeviance = (rMetadata \ "rDeviance").extract[Double]
      val rResidualDegreeOfFreedomNull = (rMetadata \ "rResidualDegreeOfFreedomNull").extract[Long]
      val rResidualDegreeOfFreedom = (rMetadata \ "rResidualDegreeOfFreedom").extract[Long]
      val rAic = (rMetadata \ "rAic").extract[Double]
      val rNumIterations = (rMetadata \ "rNumIterations").extract[Int]

      val pipeline = PipelineModel.load(pipelinePath)

      new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
        rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
        rAic, rNumIterations, isLoaded = true)
    }
  }
}