/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{DenseMatrix => OldDenseMatrix, DenseVector => OldDenseVector,
  Matrices => OldMatrices, Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.VersionUtils.majorVersion

/**
 * Params for [[PCA]] and [[PCAModel]].
 */
private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol {

  /**
   * The number of principal components.
   * @group param
   */
  final val k: IntParam = new IntParam(this, "k", "the number of principal components (> 0)",
    ParamValidators.gt(0))

  /** @group getParam */
  def getK: Int = $(k)

  /**
   * Validates and transforms the input schema.
   *
   * The input column must be of type `VectorUDT` and the output column must not already be
   * present in `schema`.
   *
   * @param schema schema of the dataset about to be fitted or transformed
   * @return `schema` with the output column appended as a non-nullable `VectorUDT` field
   */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, nullable = false)
    StructType(outputFields)
  }

}

/**
 * PCA trains a model to project vectors to a lower dimensional space of the top `k`
 * principal components (see the `k` param).
 */
@Since("1.5.0")
class PCA @Since("1.5.0") (
    @Since("1.5.0") override val uid: String)
  extends Estimator[PCAModel] with PCAParams with DefaultParamsWritable {

  @Since("1.5.0")
  def this() = this(Identifiable.randomUID("pca"))

  /** @group setParam */
  @Since("1.5.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.5.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  @Since("1.5.0")
  def setK(value: Int): this.type = set(k, value)

  /**
   * Computes a [[PCAModel]] that contains the principal components of the input vectors.
   *
   * The input column is converted to the `spark.mllib` vector type and the actual fitting is
   * delegated to the `spark.mllib` PCA implementation.
   *
   * @param dataset dataset whose input column holds the vectors to analyse
   * @return the fitted model, with this estimator set as its parent
   */
  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): PCAModel = {
    transformSchema(dataset.schema, logging = true)
    val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
    }
    val pca = new feature.PCA(k = $(k))
    val pcaModel = pca.fit(input)
    // The mllib matrix/vector results are adapted to the ml types expected by PCAModel through
    // the imported MatrixImplicits / VectorImplicits conversions.
    copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this))
  }

  @Since("1.5.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): PCA = defaultCopy(extra)
}

@Since("1.6.0")
object PCA extends DefaultParamsReadable[PCA] {

  @Since("1.6.0")
  override def load(path: String): PCA = super.load(path)
}

/**
 * Model fitted by [[PCA]]. Transforms vectors to a lower dimensional space.
 *
 * @param pc A principal components Matrix. Each column is one principal component.
 * @param explainedVariance A vector of proportions of variance explained by
 *                          each principal component.
 */
@Since("1.5.0")
class PCAModel private[ml] (
    @Since("1.5.0") override val uid: String,
    @Since("2.0.0") val pc: DenseMatrix,
    @Since("2.0.0") val explainedVariance: DenseVector)
  extends Model[PCAModel] with PCAParams with MLWritable {

  import PCAModel._

  /** @group setParam */
  @Since("1.5.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.5.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Transform a vector by computed Principal Components.
   *
   * @note Vectors to be transformed must be the same length as the source vectors given
   * to `PCA.fit()`.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    // Use the statically typed `fromML` converters on the dense companions so that no
    // unchecked downcast is needed to obtain the dense mllib types.
    val pcaModel = new feature.PCAModel($(k),
      OldDenseMatrix.fromML(pc),
      OldDenseVector.fromML(explainedVariance))

    // TODO: Make the transformer natively in ml framework to avoid extra conversion.
    val transformer: Vector => Vector = v => pcaModel.transform(OldVectors.fromML(v)).asML

    val pcaOp = udf(transformer)
    dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
  }

  @Since("1.5.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): PCAModel = {
    val copied = new PCAModel(uid, pc, explainedVariance)
    copyValues(copied, extra).setParent(parent)
  }

  @Since("1.6.0")
  override def write: MLWriter = new PCAModelWriter(this)
}

@Since("1.6.0")
object PCAModel extends MLReadable[PCAModel] {

  /** Writer that persists the params metadata and the model data of a [[PCAModel]]. */
  private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {

    private case class Data(pc: DenseMatrix, explainedVariance: DenseVector)

    /**
     * Saves the params metadata under `path` and the model data as a single-partition
     * Parquet dataset under `path/data`.
     */
    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val data = Data(instance.pc, instance.explainedVariance)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  private class PCAModelReader extends MLReader[PCAModel] {

    private val className = classOf[PCAModel].getName

    /**
     * Loads a [[PCAModel]] from data located at the input path. Note that the model includes an
     * `explainedVariance` member that is not recorded by Spark 1.6 and earlier. A model
     * can be loaded from such older data but will have an empty vector for
     * `explainedVariance`.
     *
     * @param path path to serialized model data
     * @return a [[PCAModel]]
     */
    override def load(path: String): PCAModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString
      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
        val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
          sparkSession.read.parquet(dataPath)
            .select("pc", "explainedVariance")
            .head()
        new PCAModel(metadata.uid, pc, explainedVariance)
      } else {
        // pc field is the old matrix format in Spark <= 1.6
        // explainedVariance field is not present in Spark <= 1.6
        val Row(pc: OldDenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
        // Build the empty DenseVector directly instead of downcasting a generic Vector.
        new PCAModel(metadata.uid, pc.asML, new DenseVector(Array.empty[Double]))
      }
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }

  @Since("1.6.0")
  override def read: MLReader[PCAModel] = new PCAModelReader

  @Since("1.6.0")
  override def load(path: String): PCAModel = super.load(path)
}