/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType

/**
 * Params for [[IDF]] and [[IDFModel]].
 */
private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {

  /**
   * The minimum number of documents in which a term should appear.
   * Values below 0 are rejected by the param validator.
   * Default: 0
   * @group param
   */
  final val minDocFreq = new IntParam(
    this, "minDocFreq", "minimum number of documents in which a term should appear for filtering" +
      " (>= 0)", ParamValidators.gtEq(0))

  setDefault(minDocFreq -> 0)

  /** @group getParam */
  def getMinDocFreq: Int = $(minDocFreq)

  /**
   * Validate and transform the input schema.
   *
   * Checks that `inputCol` has vector type and returns `schema` with a vector-typed
   * `outputCol` appended.
   */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
  }
}

/**
 * Compute the Inverse Document Frequency (IDF) given a collection of documents.
 */
@Since("1.4.0")
final class IDF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Estimator[IDFModel] with IDFBase with DefaultParamsWritable {

  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("idf"))

  /** @group setParam */
  @Since("1.4.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)

  /**
   * Fits an [[IDFModel]] on the vectors in `inputCol`.
   *
   * Each row's vector is converted to an old `mllib` vector and the collection is fitted by
   * `mllib.feature.IDF` using the current `minDocFreq`. Rows whose value does not match
   * `Row(v: Vector)` (for example a null entry) fail the pattern match with a `MatchError`.
   * This estimator's param values are copied onto the returned model, and its parent is set
   * to this estimator.
   */
  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): IDFModel = {
    transformSchema(dataset.schema, logging = true)
    val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
    }
    val idf = new feature.IDF($(minDocFreq)).fit(input)
    copyValues(new IDFModel(uid, idf).setParent(this))
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): IDF = defaultCopy(extra)
}

@Since("1.6.0")
object IDF extends DefaultParamsReadable[IDF] {

  @Since("1.6.0")
  override def load(path: String): IDF = super.load(path)
}

/**
 * Model fitted by [[IDF]].
 */
@Since("1.4.0")
class IDFModel private[ml] (
    @Since("1.4.0") override val uid: String,
    idfModel: feature.IDFModel)
  extends Model[IDFModel] with IDFBase with MLWritable {

  import IDFModel._

  /** @group setParam */
  @Since("1.4.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Appends `outputCol`, holding each `inputCol` vector transformed by the wrapped `mllib`
   * IDF model.
   *
   * Each vector is converted to an old `mllib` vector, transformed, and converted back with
   * `asML`.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    // Copy the field into a local val so the UDF closure captures only the mllib model.
    // Referencing `idfModel` directly inside the lambda would capture the enclosing ml
    // model (`this`) and serialize it, params included, along with the UDF.
    val oldModel = idfModel
    // TODO: Make the idfModel.transform natively in ml framework to avoid extra conversion.
    val idfUdf = udf { vec: Vector => oldModel.transform(OldVectors.fromML(vec)).asML }
    dataset.withColumn($(outputCol), idfUdf(col($(inputCol))))
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  /**
   * Creates a copy that shares the same underlying `mllib` model, with this model's params
   * plus `extra` copied onto it and the same parent.
   */
  @Since("1.4.1")
  override def copy(extra: ParamMap): IDFModel = {
    val copied = new IDFModel(uid, idfModel)
    copyValues(copied, extra).setParent(parent)
  }

  /** Returns the IDF vector (the wrapped `mllib` model's `idf`, converted to an `ml` vector). */
  @Since("2.0.0")
  def idf: Vector = idfModel.idf.asML

  @Since("1.6.0")
  override def write: MLWriter = new IDFModelWriter(this)
}

@Since("1.6.0")
object IDFModel extends MLReadable[IDFModel] {

  /**
   * Persists an [[IDFModel]] in two steps. Param metadata is saved via `DefaultParamsWriter`.
   * The IDF vector is written as a single-row parquet dataset, repartitioned to one
   * partition, under `path/data`.
   */
  private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter {

    // Row schema of the persisted data: a single `idf` vector column.
    private case class Data(idf: Vector)

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val data = Data(instance.idf)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  /**
   * Loads an [[IDFModel]] saved by `IDFModelWriter`. It reads the metadata, takes the `idf`
   * vector from the first row of `path/data`, rebuilds the model and restores the params.
   */
  private class IDFModelReader extends MLReader[IDFModel] {

    private val className = classOf[IDFModel].getName

    override def load(path: String): IDFModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.parquet(dataPath)
      // NOTE(review): the conversion presumably lets data saved with old mllib vector columns
      // load as ml vectors; confirm against MLUtils.convertVectorColumnsToML.
      val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf")
        .select("idf")
        .head()
      val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf)))
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }

  @Since("1.6.0")
  override def read: MLReader[IDFModel] = new IDFModelReader

  @Since("1.6.0")
  override def load(path: String): IDFModel = super.load(path)
}