1/*
2* Licensed to the Apache Software Foundation (ASF) under one or more
3* contributor license agreements.  See the NOTICE file distributed with
4* this work for additional information regarding copyright ownership.
5* The ASF licenses this file to You under the Apache License, Version 2.0
6* (the "License"); you may not use this file except in compliance with
7* the License.  You may obtain a copy of the License at
8*
9*    http://www.apache.org/licenses/LICENSE-2.0
10*
11* Unless required by applicable law or agreed to in writing, software
12* distributed under the License is distributed on an "AS IS" BASIS,
13* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14* See the License for the specific language governing permissions and
15* limitations under the License.
16*/
17
18package org.apache.spark.sql
19
20import java.{lang => jl}
21
22import scala.collection.JavaConverters._
23
24import org.apache.spark.annotation.InterfaceStability
25import org.apache.spark.sql.catalyst.expressions._
26import org.apache.spark.sql.functions._
27import org.apache.spark.sql.types._
28
29
/**
 * Functionality for working with missing data in `DataFrame`s: dropping rows that contain
 * null/NaN values, filling null/NaN values with constants, and replacing specific values.
 *
 * @since 1.3.1
 */
@InterfaceStability.Stable
final class DataFrameNaFunctions private[sql](df: DataFrame) {

  /**
   * Returns a new `DataFrame` that drops rows containing any null or NaN values.
   *
   * @since 1.3.1
   */
  def drop(): DataFrame = drop("any", df.columns)

  /**
   * Returns a new `DataFrame` that drops rows containing null or NaN values.
   *
   * If `how` is "any", then drop rows containing any null or NaN values.
   * If `how` is "all", then drop rows only if every column is null or NaN for that row.
   *
   * @throws IllegalArgumentException if `how` is neither "any" nor "all" (case-insensitive)
   * @since 1.3.1
   */
  def drop(how: String): DataFrame = drop(how, df.columns)

  /**
   * Returns a new `DataFrame` that drops rows containing any null or NaN values
   * in the specified columns.
   *
   * @since 1.3.1
   */
  def drop(cols: Array[String]): DataFrame = drop(cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values
   * in the specified columns.
   *
   * @since 1.3.1
   */
  def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols)

  /**
   * Returns a new `DataFrame` that drops rows containing null or NaN values
   * in the specified columns.
   *
   * If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
   * If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
   *
   * @throws IllegalArgumentException if `how` is neither "any" nor "all" (case-insensitive)
   * @since 1.3.1
   */
  def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values
   * in the specified columns.
   *
   * If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
   * If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
   *
   * @throws IllegalArgumentException if `how` is neither "any" nor "all" (case-insensitive)
   * @since 1.3.1
   */
  def drop(how: String, cols: Seq[String]): DataFrame = {
    // Locale.ROOT keeps the comparison independent of the JVM's default locale.
    how.toLowerCase(java.util.Locale.ROOT) match {
      // "any": every listed column must be non-null/non-NaN.
      case "any" => drop(cols.size, cols)
      // "all": keep the row as long as at least one listed column is non-null/non-NaN.
      case "all" => drop(1, cols)
      case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
    }
  }

  /**
   * Returns a new `DataFrame` that drops rows containing
   * less than `minNonNulls` non-null and non-NaN values.
   *
   * @since 1.3.1
   */
  def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns)

  /**
   * Returns a new `DataFrame` that drops rows containing
   * less than `minNonNulls` non-null and non-NaN values in the specified columns.
   *
   * @since 1.3.1
   */
  def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than
   * `minNonNulls` non-null and non-NaN values in the specified columns.
   *
   * @since 1.3.1
   */
  def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
    // Filtering condition:
    // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
    val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
    df.filter(Column(predicate))
  }

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
   *
   * @since 2.1.1
   */
  def fill(value: Long): DataFrame = fill(value, df.columns)

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
   *
   * @since 1.3.1
   */
  def fill(value: Double): DataFrame = fill(value, df.columns)

  /**
   * Returns a new `DataFrame` that replaces null values in string columns with `value`.
   *
   * @since 1.3.1
   */
  def fill(value: String): DataFrame = fill(value, df.columns)

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
   * If a specified column is not a numeric column, it is ignored.
   *
   * @since 2.1.1
   */
  def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toSeq)

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
   * If a specified column is not a numeric column, it is ignored.
   *
   * @since 1.3.1
   */
  def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
   * numeric columns. If a specified column is not a numeric column, it is ignored.
   *
   * @since 2.1.1
   */
  def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)

  /**
   * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
   * numeric columns. If a specified column is not a numeric column, it is ignored.
   *
   * @since 1.3.1
   */
  def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)

  /**
   * Returns a new `DataFrame` that replaces null values in specified string columns.
   * If a specified column is not a string column, it is ignored.
   *
   * @since 1.3.1
   */
  def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that replaces null values in
   * specified string columns. If a specified column is not a string column, it is ignored.
   *
   * @since 1.3.1
   */
  def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)

  /**
   * Returns a new `DataFrame` that replaces null values (and NaN values in float/double
   * columns).
   *
   * The key of the map is the column name, and the value of the map is the replacement value.
   * The value must be of the following type:
   * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`.
   * Replacement values are cast to the column data type.
   *
   * For example, the following replaces null values in column "A" with string "unknown", and
   * null values in column "B" with numeric value 1.0.
   * {{{
   *   import com.google.common.collect.ImmutableMap;
   *   df.na().fill(ImmutableMap.of("A", "unknown", "B", 1.0));
   * }}}
   *
   * @throws IllegalArgumentException if a replacement value has an unsupported type (or is null)
   * @since 1.3.1
   */
  def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that replaces null values (and NaN values in
   * float/double columns).
   *
   * The key of the map is the column name, and the value of the map is the replacement value.
   * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`.
   * Replacement values are cast to the column data type.
   *
   * For example, the following replaces null values in column "A" with string "unknown", and
   * null values in column "B" with numeric value 1.0.
   * {{{
   *   df.na.fill(Map(
   *     "A" -> "unknown",
   *     "B" -> 1.0
   *   ))
   * }}}
   *
   * @throws IllegalArgumentException if a replacement value has an unsupported type (or is null)
   * @since 1.3.1
   */
  def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)

  /**
   * Replaces values matching keys in `replacement` map with the corresponding values.
   * Key and value of `replacement` map must have the same type, and can only be numeric
   * (int, long, float, double; converted to double), strings or booleans.
   * If `col` is "*", then the replacement is applied on all string, numeric or boolean columns.
   *
   * {{{
   *   import com.google.common.collect.ImmutableMap;
   *
   *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
   *   df.na().replace("height", ImmutableMap.of(1.0, 2.0));
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
   *   df.na().replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
   *   df.na().replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
   * }}}
   *
   * @param col name of the column to apply the value replacement
   * @param replacement value replacement map, as explained above
   *
   * @since 1.3.1
   */
  def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = {
    replace[T](col, replacement.asScala.toMap)
  }

  /**
   * Replaces values matching keys in `replacement` map with the corresponding values.
   * Key and value of `replacement` map must have the same type, and can only be numeric
   * (int, long, float, double; converted to double), strings or booleans.
   *
   * {{{
   *   import com.google.common.collect.ImmutableMap;
   *
   *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
   *   df.na().replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0));
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
   *   df.na().replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
   * }}}
   *
   * @param cols list of columns to apply the value replacement
   * @param replacement value replacement map, as explained above
   *
   * @since 1.3.1
   */
  def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = {
    replace(cols.toSeq, replacement.asScala.toMap)
  }

  /**
   * (Scala-specific) Replaces values matching keys in `replacement` map.
   * Key and value of `replacement` map must have the same type, and can only be numeric
   * (int, long, float, double; converted to double), strings or booleans.
   * If `col` is "*", then the replacement is applied on all string, numeric or boolean columns.
   *
   * {{{
   *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
   *   df.na.replace("height", Map(1.0 -> 2.0))
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
   *   df.na.replace("name", Map("UNKNOWN" -> "unnamed"))
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
   *   df.na.replace("*", Map("UNKNOWN" -> "unnamed"))
   * }}}
   *
   * @param col name of the column to apply the value replacement
   * @param replacement value replacement map, as explained above
   *
   * @since 1.3.1
   */
  def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
    if (col == "*") {
      replace0(df.columns, replacement)
    } else {
      replace0(Seq(col), replacement)
    }
  }

  /**
   * (Scala-specific) Replaces values matching keys in `replacement` map.
   * Key and value of `replacement` map must have the same type, and can only be numeric
   * (int, long, float, double; converted to double), strings or booleans.
   *
   * {{{
   *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
   *   df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0))
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
   *   df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"))
   * }}}
   *
   * @param cols list of columns to apply the value replacement
   * @param replacement value replacement map, as explained above
   *
   * @since 1.3.1
   */
  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement)

  /**
   * Shared implementation of all `replace` overloads.
   *
   * Returns `df` unchanged when either `replacement` or `cols` is empty.
   *
   * @throws IllegalArgumentException if a key/value of `replacement` has an unsupported type
   */
  private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
    if (replacement.isEmpty || cols.isEmpty) {
      df
    } else {
      // The type of the first value decides how the whole map is normalized: String and Boolean
      // maps are used as-is, anything else is treated as numeric and converted to Double.
      val replacementMap: Map[_, _] = replacement.head._2 match {
        case _: String | _: Boolean => replacement
        case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) }
      }

      // The type of the first key decides which columns are eligible: DoubleType stands for
      // "any numeric column", BooleanType and StringType must match the column type exactly.
      val targetColumnType: DataType = replacement.head._1 match {
        case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType
        case _: jl.Boolean => BooleanType
        case _: String => StringType
        case other => throw unsupportedValueType(other)
      }

      val columnEquals = df.sparkSession.sessionState.analyzer.resolver
      val projections = df.schema.fields.map { f =>
        val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
        val typeMatches = f.dataType match {
          case _: NumericType => targetColumnType == DoubleType
          case other => other == targetColumnType
        }
        if (typeMatches && shouldReplace) replaceCol(f, replacementMap) else quotedCol(f.name)
      }
      df.select(projections : _*)
    }
  }

  /**
   * Shared implementation of the map-based `fill` overloads.
   *
   * Every entry is validated up front (column resolvable, value of a supported type) before any
   * projection is built. When several entries resolve to the same column, the first one wins.
   *
   * @throws IllegalArgumentException if a replacement value has an unsupported type (or is null)
   */
  private def fillMap(values: Seq[(String, Any)]): DataFrame = {
    values.foreach { case (colName, replaceValue) =>
      // Fails if the column name cannot be resolved against `df`.
      df.resolve(colName)

      replaceValue match {
        case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: jl.Boolean | _: String =>
          // Supported replacement type.
        case other => throw unsupportedValueType(other)
      }
    }

    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
    val projections = df.schema.fields.map { f =>
      values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
        // Exhaustive for the types accepted by the validation loop above.
        v match {
          case v: jl.Float => fillCol[Float](f, v)
          case v: jl.Double => fillCol[Double](f, v)
          case v: jl.Long => fillCol[Long](f, v)
          case v: jl.Integer => fillCol[Integer](f, v)
          case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue())
          case v: String => fillCol[String](f, v)
        }
      }.getOrElse(quotedCol(f.name))
    }
    df.select(projections : _*)
  }

  /**
   * Returns the top-level column `name` of `df` as a [[Column]].
   *
   * The name is wrapped in backticks (with embedded backticks doubled) so that names containing
   * dots or other special characters are not parsed as nested-field references.
   */
  private def quotedCol(name: String): Column =
    df.col("`" + name.replace("`", "``") + "`")

  /**
   * Returns a [[Column]] expression that replaces null value in `col` with `replacement`
   * (cast to the column's data type). For DoubleType/FloatType columns NaN is replaced as well.
   */
  private def fillCol[T](col: StructField, replacement: T): Column = {
    val column = quotedCol(col.name)
    val colValue = col.dataType match {
      // nanvl only supports these types; it maps NaN to null so that coalesce fills it too.
      case DoubleType | FloatType => nanvl(column, lit(null))
      case _ => column
    }
    coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
  }

  /**
   * Returns a [[Column]] expression that replaces values matching a key in `replacementMap` with
   * the corresponding value, using [[CaseKeyWhen]]. Keys and values are cast to the column's data
   * type; values that match no key are passed through unchanged.
   *
   * TODO: This can be optimized to use broadcast join when replacementMap is large.
   */
  private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
    val keyExpr = quotedCol(col.name).expr
    def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
    val branches = replacementMap.flatMap { case (source, target) =>
      Seq(buildExpr(source), buildExpr(target))
    }.toSeq
    // Trailing keyExpr is the ELSE branch: unmatched values keep their original value.
    new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
  }

  /**
   * Widens a Float/Double/Long/Int replacement value to Double.
   *
   * @throws IllegalArgumentException for any other type, including null
   */
  private def convertToDouble(v: Any): Double = v match {
    case v: Float => v.toDouble
    case v: Double => v
    case v: Long => v.toDouble
    case v: Int => v.toDouble
    case other => throw unsupportedValueType(other)
  }

  /**
   * Builds the exception reported for a replacement value of an unsupported type.
   * Null-safe: a null value is reported as type "null" instead of raising a NullPointerException.
   */
  private def unsupportedValueType(v: Any): IllegalArgumentException = {
    val typeName = Option(v).map(_.getClass.getName).getOrElse("null")
    new IllegalArgumentException(s"Unsupported value type $typeName ($v).")
  }

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in specified
   * numeric, string columns. If a specified column is not a numeric, string column,
   * it is ignored.
   *
   * A Long/Double `value` applies to every NumericType column, not only to columns of the
   * same type. For example:
   * {{{
   *   val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a", "b")
   *   input.na.fill(3.1)  // result is (3, 164.3), not (null, 164.3)
   * }}}
   *
   * @throws IllegalArgumentException if `value` is not a Long, Double or String (or is null)
   */
  private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
    // Which column data types are eligible for filling with `value`.
    val isTargetType: DataType => Boolean = value match {
      case _: Double | _: Long => {
        case _: NumericType => true
        case _ => false
      }
      case _: String => _ == StringType
      case other => throw unsupportedValueType(other)
    }

    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
    val projections = df.schema.fields.map { f =>
      // Only fill if the column is part of the cols list and has an eligible type.
      if (isTargetType(f.dataType) && cols.exists(col => columnEquals(f.name, col))) {
        fillCol[T](f, value)
      } else {
        quotedCol(f.name)
      }
    }
    df.select(projections : _*)
  }
}
474