/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType

/**
 * Params for [[IDF]] and [[IDFModel]].
 */
private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {

  /**
   * The minimum number of documents in which a term should appear. Terms that appear in
   * fewer documents are assigned an IDF of 0, zeroing out their TF-IDF values.
   * Default: 0
   * @group param
   */
  final val minDocFreq = new IntParam(
    this, "minDocFreq", "minimum number of documents in which a term should appear for filtering" +
      " (>= 0)", ParamValidators.gtEq(0))

  setDefault(minDocFreq -> 0)

  /** @group getParam */
  def getMinDocFreq: Int = $(minDocFreq)

  /**
   * Validate and transform the input schema.
   */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
  }
}

/**
 * Compute the Inverse Document Frequency (IDF) given a collection of documents.
 *
 * The IDF of a term is computed as log((m + 1) / (d(t) + 1)), where m is the total number of
 * documents in the collection and d(t) is the number of documents that contain term t.
 */
@Since("1.4.0")
final class IDF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Estimator[IDFModel] with IDFBase with DefaultParamsWritable {

  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("idf"))

  /** @group setParam */
  @Since("1.4.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)

  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): IDFModel = {
    transformSchema(dataset.schema, logging = true)
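    // The IDF implementation lives in mllib, so convert each ml Vector in the input
    // column to the old mllib representation before fitting.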
    val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
    }
    val idf = new feature.IDF($(minDocFreq)).fit(input)
    copyValues(new IDFModel(uid, idf).setParent(this))
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): IDF = defaultCopy(extra)
}

@Since("1.6.0")
object IDF extends DefaultParamsReadable[IDF] {

  @Since("1.6.0")
  override def load(path: String): IDF = super.load(path)
}
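
// A minimal usage sketch (illustrative only, not part of this file's API): it assumes an
// active SparkSession `spark` and uses Tokenizer and HashingTF from this package to build
// the term-frequency vectors that IDF expects; the column names are made up.
//
//   import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
//
//   val df = spark.createDataFrame(Seq(
//     (0, "a a b"),
//     (1, "a c")
//   )).toDF("id", "text")
//   val words = new Tokenizer().setInputCol("text").setOutputCol("words").transform(df)
//   val tf = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").transform(words)
//   val model = new IDF().setInputCol("rawFeatures").setOutputCol("features").fit(tf)
//   model.transform(tf).select("features").show(truncate = false)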

/**
 * Model fitted by [[IDF]].
 */
@Since("1.4.0")
class IDFModel private[ml] (
    @Since("1.4.0") override val uid: String,
    idfModel: feature.IDFModel)
  extends Model[IDFModel] with IDFBase with MLWritable {

  import IDFModel._

  /** @group setParam */
  @Since("1.4.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    // TODO: Implement idfModel.transform natively in the ml framework to avoid the extra
    // conversion.
    val idf = udf { vec: Vector => idfModel.transform(OldVectors.fromML(vec)).asML }
    dataset.withColumn($(outputCol), idf(col($(inputCol))))
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): IDFModel = {
    val copied = new IDFModel(uid, idfModel)
    copyValues(copied, extra).setParent(parent)
  }

  /** Returns the IDF vector. */
  @Since("2.0.0")
  def idf: Vector = idfModel.idf.asML

  @Since("1.6.0")
  override def write: MLWriter = new IDFModelWriter(this)
}
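
// A save/load round-trip sketch (the path is illustrative):
//
//   model.write.overwrite().save("/tmp/spark-idf-model")
//   val restored = IDFModel.load("/tmp/spark-idf-model")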

@Since("1.6.0")
object IDFModel extends MLReadable[IDFModel] {

  private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter {

    private case class Data(idf: Vector)

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
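      // The params are stored in the metadata written above; the IDF vector is the only
      // additional state that needs to be persisted.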
      val data = Data(instance.idf)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  private class IDFModelReader extends MLReader[IDFModel] {

    private val className = classOf[IDFModel].getName

    override def load(path: String): IDFModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.parquet(dataPath)
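      // convertVectorColumnsToML tolerates models whose "idf" column was saved in the
      // pre-2.0 mllib vector format as well as the current ml format.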
      val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf")
        .select("idf")
        .head()
      val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf)))
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }

  @Since("1.6.0")
  override def read: MLReader[IDFModel] = new IDFModelReader

  @Since("1.6.0")
  override def load(path: String): IDFModel = super.load(path)
}