/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.{lang => jl}
import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


/**
 * Functionality for working with missing data in `DataFrame`s.
 *
 * @since 1.3.1
 */
@InterfaceStability.Stable
final class DataFrameNaFunctions private[sql](df: DataFrame) {

  /**
   * Returns a new `DataFrame` that drops rows containing any null or NaN values.
   *
   * @since 1.3.1
   */
  def drop(): DataFrame = drop("any", quotedColumnNames)

  /**
   * Returns a new `DataFrame` that drops rows containing null or NaN values.
   *
   * If `how` is "any", then drop rows containing any null or NaN values.
   * If `how` is "all", then drop rows only if every column is null or NaN for that row.
   *
   * @since 1.3.1
   */
  def drop(how: String): DataFrame = drop(how, quotedColumnNames)

  /**
   * Returns a new `DataFrame` that drops rows containing any null or NaN values
   * in the specified columns.
   *
   * @since 1.3.1
   */
  def drop(cols: Array[String]): DataFrame = drop(cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values
   * in the specified columns.
   *
   * @since 1.3.1
   */
  def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols)

  /**
   * Returns a new `DataFrame` that drops rows containing null or NaN values
   * in the specified columns.
   *
   * If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
   * If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
   *
   * @since 1.3.1
   */
  def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values
   * in the specified columns.
   *
   * If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
   * If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
   * `how` is matched case-insensitively.
   *
   * @throws IllegalArgumentException if `how` is neither "any" nor "all"
   * @since 1.3.1
   */
  def drop(how: String, cols: Seq[String]): DataFrame = {
    // Locale.ROOT keeps the comparison independent of the JVM default locale.
    how.toLowerCase(Locale.ROOT) match {
      case "any" => drop(cols.size, cols)
      case "all" => drop(1, cols)
      case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
    }
  }

  /**
   * Returns a new `DataFrame` that drops rows containing
   * fewer than `minNonNulls` non-null and non-NaN values.
   *
   * @since 1.3.1
   */
  def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, quotedColumnNames)

  /**
   * Returns a new `DataFrame` that drops rows containing
   * fewer than `minNonNulls` non-null and non-NaN values in the specified columns.
   *
   * @since 1.3.1
   */
  def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that drops rows containing fewer than
   * `minNonNulls` non-null and non-NaN values in the specified columns.
   *
   * @since 1.3.1
   */
  def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
    // Filtering condition:
    // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
    val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
    df.filter(Column(predicate))
  }

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
   *
   * @since 2.1.1
   */
  def fill(value: Long): DataFrame = fill(value, df.columns)

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
   *
   * @since 1.3.1
   */
  def fill(value: Double): DataFrame = fill(value, df.columns)

  /**
   * Returns a new `DataFrame` that replaces null values in string columns with `value`.
   *
   * @since 1.3.1
   */
  def fill(value: String): DataFrame = fill(value, df.columns)

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
   * If a specified column is not a numeric column, it is ignored.
   *
   * @since 2.1.1
   */
  def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toSeq)

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
   * If a specified column is not a numeric column, it is ignored.
   *
   * @since 1.3.1
   */
  def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
   * numeric columns. If a specified column is not a numeric column, it is ignored.
   *
   * @since 2.1.1
   */
  def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)

  /**
   * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
   * numeric columns. If a specified column is not a numeric column, it is ignored.
   *
   * @since 1.3.1
   */
  def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)


  /**
   * Returns a new `DataFrame` that replaces null values in specified string columns.
   * If a specified column is not a string column, it is ignored.
   *
   * @since 1.3.1
   */
  def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that replaces null values in
   * specified string columns. If a specified column is not a string column, it is ignored.
   *
   * @since 1.3.1
   */
  def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)

  /**
   * Returns a new `DataFrame` that replaces null values (and NaN values in float/double
   * columns).
   *
   * The key of the map is the column name, and the value of the map is the replacement value.
   * The value must be of one of the following types:
   * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`.
   * Replacement values are cast to the column data type.
   *
   * For example, the following replaces null values in column "A" with string "unknown", and
   * null values in column "B" with numeric value 1.0.
   * {{{
   *   import com.google.common.collect.ImmutableMap;
   *   df.na().fill(ImmutableMap.of("A", "unknown", "B", 1.0));
   * }}}
   *
   * @throws IllegalArgumentException if a replacement value has an unsupported type
   * @since 1.3.1
   */
  def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq)

  /**
   * (Scala-specific) Returns a new `DataFrame` that replaces null values (and NaN values in
   * float/double columns).
   *
   * The key of the map is the column name, and the value of the map is the replacement value.
   * The value must be of one of the following types:
   * `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`.
   * Replacement values are cast to the column data type.
   *
   * For example, the following replaces null values in column "A" with string "unknown", and
   * null values in column "B" with numeric value 1.0.
   * {{{
   *   df.na.fill(Map(
   *     "A" -> "unknown",
   *     "B" -> 1.0
   *   ))
   * }}}
   *
   * @throws IllegalArgumentException if a replacement value has an unsupported type
   * @since 1.3.1
   */
  def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)

  /**
   * Replaces values matching keys in `replacement` map with the corresponding values.
   * Key and value of `replacement` map must have the same type, and
   * can only be doubles, strings or booleans.
   * If `col` is "*", then the replacement is applied on all string, numeric or boolean columns.
   *
   * {{{
   *   import com.google.common.collect.ImmutableMap;
   *
   *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
   *   df.na().replace("height", ImmutableMap.of(1.0, 2.0));
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
   *   df.na().replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
   *   df.na().replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
   * }}}
   *
   * @param col name of the column to apply the value replacement
   * @param replacement value replacement map, as explained above
   *
   * @since 1.3.1
   */
  def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = {
    replace[T](col, replacement.asScala.toMap)
  }

  /**
   * Replaces values matching keys in `replacement` map with the corresponding values.
   * Key and value of `replacement` map must have the same type, and
   * can only be doubles, strings or booleans.
   *
   * {{{
   *   import com.google.common.collect.ImmutableMap;
   *
   *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
   *   df.na().replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0));
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
   *   df.na().replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
   * }}}
   *
   * @param cols list of columns to apply the value replacement
   * @param replacement value replacement map, as explained above
   *
   * @since 1.3.1
   */
  def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = {
    replace(cols.toSeq, replacement.asScala.toMap)
  }

  /**
   * (Scala-specific) Replaces values matching keys in `replacement` map.
   * Key and value of `replacement` map must have the same type, and
   * can only be doubles, strings or booleans.
   * If `col` is "*",
   * then the replacement is applied on all string columns, numeric columns or boolean columns.
   *
   * {{{
   *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
   *   df.na.replace("height", Map(1.0 -> 2.0))
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
   *   df.na.replace("name", Map("UNKNOWN" -> "unnamed"))
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
   *   df.na.replace("*", Map("UNKNOWN" -> "unnamed"))
   * }}}
   *
   * @param col name of the column to apply the value replacement
   * @param replacement value replacement map, as explained above
   *
   * @since 1.3.1
   */
  def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
    if (col == "*") {
      replace0(df.columns, replacement)
    } else {
      replace0(Seq(col), replacement)
    }
  }

  /**
   * (Scala-specific) Replaces values matching keys in `replacement` map.
   * Key and value of `replacement` map must have the same type, and
   * can only be doubles, strings or booleans.
   *
   * {{{
   *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
   *   df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0))
   *
   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
   *   df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"))
   * }}}
   *
   * @param cols list of columns to apply the value replacement
   * @param replacement value replacement map, as explained above
   *
   * @since 1.3.1
   */
  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement)

  /**
   * Shared implementation of all `replace` overloads.
   *
   * The type of the first entry's value decides how the map is normalized: string and boolean
   * maps are used as-is, anything else is converted to `Map[Double, Double]`. The type of the
   * first entry's key decides which columns are eligible: all numeric columns (for numeric keys),
   * boolean columns, or string columns.
   *
   * @throws IllegalArgumentException if the map contains an unsupported key/value type
   */
  private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
    if (replacement.isEmpty || cols.isEmpty) {
      df
    } else {
      val (firstKey, firstValue) = replacement.head

      // replacementMap is either Map[String, String], Map[Boolean, Boolean] or
      // Map[Double, Double]. convertToDouble throws for unsupported types.
      val replacementMap: Map[_, _] = firstValue match {
        case _: String | _: Boolean => replacement
        case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) }
      }

      // targetColumnType is either DoubleType, BooleanType or StringType.
      val targetColumnType = firstKey match {
        case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType
        case _: jl.Boolean => BooleanType
        case _: String => StringType
        case other => throw unsupportedValueType(other)
      }

      val columnEquals = df.sparkSession.sessionState.analyzer.resolver
      val projections = df.schema.fields.map { f =>
        val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
        // Numeric keys apply to every numeric column, not just DoubleType ones.
        val typeMatches = f.dataType == targetColumnType ||
          (targetColumnType == DoubleType && f.dataType.isInstanceOf[NumericType])
        if (shouldReplace && typeMatches) replaceCol(f, replacementMap) else columnOf(f.name)
      }
      df.select(projections : _*)
    }
  }

  /**
   * Shared implementation of the map-based `fill` overloads.
   *
   * Every entry is validated before any projection is built: the column name must resolve
   * against `df`, and the value must be a Double, Float, Integer, Long, Boolean or String.
   *
   * @throws IllegalArgumentException if a value has an unsupported type
   */
  private def fillMap(values: Seq[(String, Any)]): DataFrame = {
    // Error handling
    values.foreach { case (colName, replaceValue) =>
      // Check column name exists
      df.resolve(colName)

      // Check data type
      replaceValue match {
        case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: jl.Boolean | _: String =>
          // This is good
        case other => throw unsupportedValueType(other)
      }
    }

    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
    val projections = df.schema.fields.map { f =>
      values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
        v match {
          case v: jl.Float => fillCol[Float](f, v)
          case v: jl.Double => fillCol[Double](f, v)
          case v: jl.Long => fillCol[Long](f, v)
          case v: jl.Integer => fillCol[Integer](f, v)
          case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue())
          case v: String => fillCol[String](f, v)
        }
      }.getOrElse(columnOf(f.name))
    }
    df.select(projections : _*)
  }

  /**
   * Returns a [[Column]] expression that replaces null values in `col` with `replacement`
   * (cast to the column's data type). For DoubleType and FloatType columns NaN is first turned
   * into null, so NaN values are replaced as well.
   */
  private def fillCol[T](col: StructField, replacement: T): Column = {
    val colValue = col.dataType match {
      case DoubleType | FloatType =>
        nanvl(columnOf(col.name), lit(null)) // nanvl only supports these types
      case _ => columnOf(col.name)
    }
    coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
  }

  /**
   * Returns a [[Column]] expression that replaces value matching key in `replacementMap` with
   * value in `replacementMap`, using [[CaseWhen]]. Both keys and values are cast to the column's
   * data type; unmatched values are kept unchanged.
   *
   * TODO: This can be optimized to use broadcast join when replacementMap is large.
   */
  private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
    val keyExpr = columnOf(col.name).expr
    def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
    val branches = replacementMap.flatMap { case (source, target) =>
      Seq(buildExpr(source), buildExpr(target))
    }.toSeq
    new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
  }

  /**
   * Widens a Float, Double, Long or Int replacement value to Double.
   *
   * @throws IllegalArgumentException for any other value type (including null)
   */
  private def convertToDouble(v: Any): Double = v match {
    case v: Float => v.toDouble
    case v: Double => v
    case v: Long => v.toDouble
    case v: Int => v.toDouble
    case other => throw unsupportedValueType(other)
  }

  /**
   * Builds the error raised for a replacement value of an unsupported type. Null-safe, so a
   * null value yields this error instead of a NullPointerException from `getClass`.
   */
  private def unsupportedValueType(v: Any): IllegalArgumentException = {
    val typeName = if (v == null) "null" else v.getClass.getName
    new IllegalArgumentException(s"Unsupported value type $typeName ($v).")
  }

  /**
   * Backtick-quotes a top-level column name (doubling any embedded backticks) so that name
   * resolution treats it as a single identifier rather than a dotted nested-field path.
   */
  private def quoteColumnName(name: String): String = "`" + name.replace("`", "``") + "`"

  /** All top-level column names of `df`, quoted so they can be passed to `df.resolve`. */
  private def quotedColumnNames: Array[String] = df.columns.map(quoteColumnName)

  /**
   * Returns the top-level column of `df` whose name is exactly `name`. Unlike `df.col(name)`,
   * a name such as "a.b" is not misread as field `b` of a struct column `a`.
   */
  private def columnOf(name: String): Column = df.col(quoteColumnName(name))

  /**
   * Returns a new `DataFrame` that replaces null or NaN values in specified
   * numeric, string columns. If a specified column is not a numeric, string column,
   * it is ignored.
   *
   * @throws IllegalArgumentException if `value` is not a Double, Long or String
   */
  private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
    // A Long/Double fill value applies to every NumericType column, not only to columns of the
    // exact same type. For example:
    //   val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a", "b")
    //   input.na.fill(3.1)
    // yields (3, 164.3), not (null, 164.3): the value is cast to each column's type.
    val targetType = value match {
      case _: Double | _: Long => NumericType
      case _: String => StringType
      case other => throw unsupportedValueType(other)
    }

    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
    val projections = df.schema.fields.map { f =>
      val typeMatches = (targetType, f.dataType) match {
        case (NumericType, dt) => dt.isInstanceOf[NumericType]
        case (StringType, dt) => dt == StringType
      }
      // Only fill if the column is part of the cols list.
      if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
        fillCol[T](f, value)
      } else {
        columnOf(f.name)
      }
    }
    df.select(projections : _*)
  }
}