1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.ml.r
19
20import org.apache.hadoop.fs.Path
21import org.json4s._
22import org.json4s.JsonDSL._
23import org.json4s.jackson.JsonMethods._
24
25import org.apache.spark.ml.{Pipeline, PipelineModel}
26import org.apache.spark.ml.attribute.AttributeGroup
27import org.apache.spark.ml.feature.RFormula
28import org.apache.spark.ml.r.RWrapperUtils._
29import org.apache.spark.ml.regression._
30import org.apache.spark.ml.util._
31import org.apache.spark.sql._
32
/**
 * Wrapper around a fitted two-stage pipeline whose second stage is a
 * [[GeneralizedLinearRegressionModel]], carrying model statistics alongside it.
 *
 * Instances are created only by the companion object (the constructor is private): either by
 * `fit`, or by the companion's reader when a saved wrapper is loaded.
 *
 * @param pipeline fitted pipeline; stage 1 must be the GLM model
 * @param rFeatures feature names, with "(Intercept)" first when an intercept was fitted
 * @param rCoefficients flattened coefficient statistics as assembled by the companion's `fit`
 * @param rDispersion dispersion taken from the training summary
 * @param rNullDeviance null deviance taken from the training summary
 * @param rDeviance deviance taken from the training summary
 * @param rResidualDegreeOfFreedomNull residual degrees of freedom of the null model
 * @param rResidualDegreeOfFreedom residual degrees of freedom of the fitted model
 * @param rAic AIC taken from the training summary
 * @param rNumIterations number of iterations taken from the training summary
 * @param isLoaded true when this wrapper was restored by the companion's reader
 */
private[r] class GeneralizedLinearRegressionWrapper private (
    val pipeline: PipelineModel,
    val rFeatures: Array[String],
    val rCoefficients: Array[Double],
    val rDispersion: Double,
    val rNullDeviance: Double,
    val rDeviance: Double,
    val rResidualDegreeOfFreedomNull: Long,
    val rResidualDegreeOfFreedom: Long,
    val rAic: Double,
    val rNumIterations: Int,
    val isLoaded: Boolean = false) extends MLWritable {

  // The GLM model is the second pipeline stage (index 1); the cast fails for any other stage type.
  private val glm: GeneralizedLinearRegressionModel = {
    val modelStage = pipeline.stages(1)
    modelStage.asInstanceOf[GeneralizedLinearRegressionModel]
  }

  /** Family reported by the underlying GLM model. */
  lazy val rFamily: String = glm.getFamily

  /**
   * Residuals of the summary's default type, computed on first access.
   * NOTE(review): presumably deviance residuals, per the name - confirm against the summary API.
   * NOTE(review): reads `glm.summary`; presumably unavailable when `isLoaded` is true - confirm.
   */
  lazy val rDevianceResiduals: DataFrame = glm.summary.residuals()

  /** Residuals of the requested type, delegated to the GLM training summary. */
  def residuals(residualsType: String): DataFrame = {
    val trainingSummary = glm.summary
    trainingSummary.residuals(residualsType)
  }

  /** Runs the whole pipeline on `dataset` and removes the GLM's features column from the output. */
  def transform(dataset: Dataset[_]): DataFrame = {
    val scored = pipeline.transform(dataset)
    scored.drop(glm.getFeaturesCol)
  }

  /** Writer that persists this wrapper's R metadata together with its pipeline. */
  override def write: MLWriter = {
    new GeneralizedLinearRegressionWrapper.GeneralizedLinearRegressionWrapperWriter(this)
  }
}
62
/**
 * Companion of the GLM wrapper class: fits a wrapper from an R formula and implements its
 * persistence (a JSON "rMetadata" text file plus the saved "pipeline").
 */
private[r] object GeneralizedLinearRegressionWrapper
  extends MLReadable[GeneralizedLinearRegressionWrapper] {

  /**
   * Fits an RFormula + GeneralizedLinearRegression pipeline on `data` and collects the
   * statistics exposed by the wrapper.
   *
   * Layout of the resulting `rCoefficients`, all in `rFeatures` order (intercept first when
   * one is fitted):
   *  - normal solver: estimates ++ standard errors ++ t-values ++ p-values
   *  - otherwise: estimates only
   *
   * @param formula R model formula handed to `RFormula`
   * @param data training data
   * @param family value passed to `setFamily`
   * @param link value passed to `setLink`
   * @param tol value passed to `setTol`
   * @param maxIter value passed to `setMaxIter`
   * @param weightCol weight column name; when `null` the estimator's weight column is left unset
   * @param regParam value passed to `setRegParam`
   * @return wrapper holding the fitted pipeline and its summary statistics
   */
  def fit(
      formula: String,
      data: DataFrame,
      family: String,
      link: String,
      tol: Double,
      maxIter: Int,
      weightCol: String,
      regParam: Double): GeneralizedLinearRegressionWrapper = {
    val rFormula = new RFormula().setFormula(formula)
    checkDataColumns(rFormula, data)
    val rFormulaModel = rFormula.fit(data)
    // get feature names from the attribute metadata of the transformed features column
    val schema = rFormulaModel.transform(data).schema
    val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
      .attributes.get
    val features = featureAttrs.map(_.name.get)
    // assemble and fit the pipeline
    val glr = new GeneralizedLinearRegression()
      .setFamily(family)
      .setLink(link)
      .setFitIntercept(rFormula.hasIntercept)
      .setTol(tol)
      .setMaxIter(maxIter)
      .setRegParam(regParam)
      .setFeaturesCol(rFormula.getFeaturesCol)
    // Treat a missing (null) weight column as "no weights" instead of storing null as the
    // param value; any non-null name, including "", is forwarded unchanged.
    Option(weightCol).foreach(name => glr.setWeightCol(name))
    val pipeline = new Pipeline()
      .setStages(Array(rFormulaModel, glr))
      .fit(data)

    val glm: GeneralizedLinearRegressionModel =
      pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
    val summary = glm.summary
    val fitIntercept = glm.getFitIntercept

    val rFeatures: Array[String] = if (fitIntercept) {
      Array("(Intercept)") ++ features
    } else {
      features
    }

    // Point estimates in rFeatures order.
    val estimates: Array[Double] = if (fitIntercept) {
      Array(glm.intercept) ++ glm.coefficients.toArray
    } else {
      glm.coefficients.toArray
    }

    val rCoefficients: Array[Double] = if (summary.isNormalSolver) {
      // Per-coefficient statistics are only read under the normal solver. With an intercept,
      // their last entry is moved to the front so they line up with rFeatures.
      def aligned(stats: Array[Double]): Array[Double] = {
        if (fitIntercept) interceptFirst(stats) else stats
      }
      estimates ++ aligned(summary.coefficientStandardErrors) ++
        aligned(summary.tValues) ++ aligned(summary.pValues)
    } else {
      estimates
    }

    val rDispersion: Double = summary.dispersion
    val rNullDeviance: Double = summary.nullDeviance
    val rDeviance: Double = summary.deviance
    val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull
    val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom
    val rAic: Double = summary.aic
    val rNumIterations: Int = summary.numIterations

    new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
      rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
      rAic, rNumIterations)
  }

  /** Moves the last element of `values` to the front, keeping the order of the others. */
  private def interceptFirst(values: Array[Double]): Array[Double] = {
    Array(values.last) ++ values.dropRight(1)
  }

  /** Reader that restores a wrapper written by `GeneralizedLinearRegressionWrapperWriter`. */
  override def read: MLReader[GeneralizedLinearRegressionWrapper] =
    new GeneralizedLinearRegressionWrapperReader

  /** Loads a wrapper from `path`; delegates to the inherited `load`. */
  override def load(path: String): GeneralizedLinearRegressionWrapper = super.load(path)

  /**
   * Saves a wrapper under `path` as two children: "rMetadata" (one line of JSON) and
   * "pipeline" (the fitted pipeline).
   */
  class GeneralizedLinearRegressionWrapperWriter(instance: GeneralizedLinearRegressionWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = ("class" -> instance.getClass.getName) ~
        ("rFeatures" -> instance.rFeatures.toSeq) ~
        ("rCoefficients" -> instance.rCoefficients.toSeq) ~
        ("rDispersion" -> instance.rDispersion) ~
        ("rNullDeviance" -> instance.rNullDeviance) ~
        ("rDeviance" -> instance.rDeviance) ~
        ("rResidualDegreeOfFreedomNull" -> instance.rResidualDegreeOfFreedomNull) ~
        ("rResidualDegreeOfFreedom" -> instance.rResidualDegreeOfFreedom) ~
        ("rAic" -> instance.rAic) ~
        ("rNumIterations" -> instance.rNumIterations)
      val rMetadataJson: String = compact(render(rMetadata))
      // Single-partition RDD, so the metadata is written as one text part.
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }

  /**
   * Restores a wrapper from the "rMetadata" and "pipeline" children of `path`. The result is
   * constructed with `isLoaded = true`.
   */
  class GeneralizedLinearRegressionWrapperReader
    extends MLReader[GeneralizedLinearRegressionWrapper] {

    override def load(path: String): GeneralizedLinearRegressionWrapper = {
      // Explicitly typed so the implicit used by `extract` does not depend on type inference.
      implicit val format: Formats = DefaultFormats
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      // The metadata JSON is the first line of the text file.
      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
      val rMetadata = parse(rMetadataStr)
      val rFeatures = (rMetadata \ "rFeatures").extract[Array[String]]
      val rCoefficients = (rMetadata \ "rCoefficients").extract[Array[Double]]
      val rDispersion = (rMetadata \ "rDispersion").extract[Double]
      val rNullDeviance = (rMetadata \ "rNullDeviance").extract[Double]
      val rDeviance = (rMetadata \ "rDeviance").extract[Double]
      val rResidualDegreeOfFreedomNull = (rMetadata \ "rResidualDegreeOfFreedomNull").extract[Long]
      val rResidualDegreeOfFreedom = (rMetadata \ "rResidualDegreeOfFreedom").extract[Long]
      val rAic = (rMetadata \ "rAic").extract[Double]
      val rNumIterations = (rMetadata \ "rNumIterations").extract[Int]

      val pipeline = PipelineModel.load(pipelinePath)

      new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
        rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
        rAic, rNumIterations, isLoaded = true)
    }
  }
}
211