/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{DenseMatrix => OldDenseMatrix, DenseVector => OldDenseVector,
  Matrices => OldMatrices, Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.VersionUtils.majorVersion

/**
 * Params for [[PCA]] and [[PCAModel]].
 */
private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol {

  /**
   * The number of principal components.
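   * Must be positive; the underlying `spark.mllib` implementation additionally
   * requires `k` to be at most the number of features in the input vectors at fit time.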
   * @group param
   */
  final val k: IntParam = new IntParam(this, "k", "the number of principal components (> 0)",
    ParamValidators.gt(0))

  /** @group getParam */
  def getK: Int = $(k)

  /** Validates and transforms the input schema. */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
    StructType(outputFields)
  }

}

/**
 * PCA trains a model to project vectors to a lower-dimensional space spanned by the top
 * `k` principal components.
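 *
 * A minimal usage sketch, assuming a DataFrame `df` with a vector column "features"
 * (the column names here are illustrative):
 * {{{
 *   val pca = new PCA()
 *     .setInputCol("features")
 *     .setOutputCol("pcaFeatures")
 *     .setK(3)
 *   val model = pca.fit(df)
 * }}}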
 */
@Since("1.5.0")
class PCA @Since("1.5.0") (
    @Since("1.5.0") override val uid: String)
  extends Estimator[PCAModel] with PCAParams with DefaultParamsWritable {

  @Since("1.5.0")
  def this() = this(Identifiable.randomUID("pca"))

  /** @group setParam */
  @Since("1.5.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.5.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  @Since("1.5.0")
  def setK(value: Int): this.type = set(k, value)

  /**
   * Computes a [[PCAModel]] that contains the principal components of the input vectors.
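   *
   * Fitting delegates to `spark.mllib`'s `feature.PCA`, which derives the top `k`
   * principal components from an SVD of the covariance matrix of the input.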
   */
  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): PCAModel = {
    transformSchema(dataset.schema, logging = true)
    val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
    }
    val pca = new feature.PCA(k = $(k))
    val pcaModel = pca.fit(input)
    copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this))
  }

  @Since("1.5.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): PCA = defaultCopy(extra)
}

@Since("1.6.0")
object PCA extends DefaultParamsReadable[PCA] {

  @Since("1.6.0")
  override def load(path: String): PCA = super.load(path)
}

/**
 * Model fitted by [[PCA]]. Transforms vectors to a lower-dimensional space.
 *
 * @param pc A principal components Matrix. Each column is one principal component.
 * @param explainedVariance A vector of proportions of variance explained by
 *                          each principal component.
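 *
 * A minimal sketch of applying a fitted model (`model` and `df` are the illustrative
 * names used in the [[PCA]] example above):
 * {{{
 *   val projected = model.transform(df)   // appends a column of `k`-dimensional vectors
 *   val variance = model.explainedVariance
 * }}}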
 */
@Since("1.5.0")
class PCAModel private[ml] (
    @Since("1.5.0") override val uid: String,
    @Since("2.0.0") val pc: DenseMatrix,
    @Since("2.0.0") val explainedVariance: DenseVector)
  extends Model[PCAModel] with PCAParams with MLWritable {

  import PCAModel._

  /** @group setParam */
  @Since("1.5.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.5.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Transforms a vector by the computed principal components.
   *
   * @note Vectors to be transformed must have the same length as the source vectors given
   * to `PCA.fit()`.
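   *
   * Concretely, each input vector `v` is projected to `pc.transpose * v`, yielding a
   * vector of length `k`.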
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val pcaModel = new feature.PCAModel($(k),
      OldMatrices.fromML(pc).asInstanceOf[OldDenseMatrix],
      OldVectors.fromML(explainedVariance).asInstanceOf[OldDenseVector])

    // TODO: Implement the transformer natively in the ml framework to avoid the conversion.
    val transformer: Vector => Vector = v => pcaModel.transform(OldVectors.fromML(v)).asML

    val pcaOp = udf(transformer)
    dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
  }

  @Since("1.5.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): PCAModel = {
    val copied = new PCAModel(uid, pc, explainedVariance)
    copyValues(copied, extra).setParent(parent)
  }

  @Since("1.6.0")
  override def write: MLWriter = new PCAModelWriter(this)
}

@Since("1.6.0")
object PCAModel extends MLReadable[PCAModel] {

  private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {

    private case class Data(pc: DenseMatrix, explainedVariance: DenseVector)

    override protected def saveImpl(path: String): Unit = {
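      // On-disk layout: param values and class metadata are written as JSON under
      // `path/metadata` (via DefaultParamsWriter); the principal components and
      // explained variance are written as a single-row Parquet file under `path/data`.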
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val data = Data(instance.pc, instance.explainedVariance)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  private class PCAModelReader extends MLReader[PCAModel] {

    private val className = classOf[PCAModel].getName

    /**
     * Loads a [[PCAModel]] from data located at the input path. Note that the model includes an
     * `explainedVariance` member that is not recorded by Spark 1.6 and earlier. A model
     * can be loaded from such older data but will have an empty vector for
     * `explainedVariance`.
     *
     * @param path path to serialized model data
     * @return a [[PCAModel]]
     */
    override def load(path: String): PCAModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString
      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
        val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
          sparkSession.read.parquet(dataPath)
            .select("pc", "explainedVariance")
            .head()
        new PCAModel(metadata.uid, pc, explainedVariance)
      } else {
        // The `pc` field uses the old matrix format from Spark <= 1.6, and the
        // `explainedVariance` field is not present in Spark <= 1.6.
        val Row(pc: OldDenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
        new PCAModel(metadata.uid, pc.asML,
          Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
      }
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }

  @Since("1.6.0")
  override def read: MLReader[PCAModel] = new PCAModelReader

  @Since("1.6.0")
  override def load(path: String): PCAModel = super.load(path)
}