/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._


/**
 * An expression that is evaluated to the first non-null input.
 *
 * {{{
 *   coalesce(1, 2) => 1
 *   coalesce(null, 1, 2) => 1
 *   coalesce(null, null, 2) => 2
 *   coalesce(null, null, null) => null
 * }}}
 *
 * Children are evaluated left to right and evaluation stops at the first non-null value,
 * in both the interpreted (`eval`) and generated-code (`doGenCode`) paths.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2, ...) - Returns the first non-null argument if exists. Otherwise, null.",
  extended = """
    Examples:
      > SELECT _FUNC_(NULL, 1, NULL);
       1
  """)
// scalastyle:on line.size.limit
case class Coalesce(children: Seq[Expression]) extends Expression {

  /** Coalesce is nullable if all of its children are nullable, or if it has no children. */
  override def nullable: Boolean = children.forall(_.nullable)

  // Coalesce is foldable if all children are foldable.
  override def foldable: Boolean = children.forall(_.foldable)

  /**
   * Rejects an empty argument list, then delegates a same-type check over all children to
   * `TypeUtils.checkForSameTypeInputExpr`.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children == Nil) {
      TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
    } else {
      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
    }
  }

  // The result type is the first child's type; `checkInputDataTypes` rejects the empty case.
  override def dataType: DataType = children.head.dataType

  /** Returns the first non-null child value, or null if every child evaluates to null. */
  override def eval(input: InternalRow): Any = {
    var result: Any = null
    val childIterator = children.iterator
    // Stop as soon as a non-null value is found so later children are never evaluated.
    while (childIterator.hasNext && result == null) {
      result = childIterator.next().eval(input)
    }
    result
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val first = children(0)
    val rest = children.drop(1)
    val firstEval = first.genCode(ctx)
    // The first child seeds the result; each later child's code is emitted inside
    // `if (isNull)` so it only runs while no non-null value has been found yet.
    ev.copy(code = s"""
      ${firstEval.code}
      boolean ${ev.isNull} = ${firstEval.isNull};
      ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value};""" +
      rest.map { e =>
        val eval = e.genCode(ctx)
        s"""
        if (${ev.isNull}) {
          ${eval.code}
          if (!${eval.isNull}) {
            ${ev.isNull} = false;
            ${ev.value} = ${eval.value};
          }
        }
      """
      }.mkString("\n"))
  }
}


/**
 * `ifnull(left, right)`: the auxiliary constructor builds `child` as
 * `Coalesce(Seq(left, right))`, the expression this one is replaced by.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise.",
  extended = """
    Examples:
      > SELECT _FUNC_(NULL, array('2'));
       ["2"]
  """)
case class IfNull(left: Expression, right: Expression, child: Expression)
  extends RuntimeReplaceable {

  def this(left: Expression, right: Expression) = {
    this(left, right, Coalesce(Seq(left, right)))
  }

  // Only the user-visible arguments are listed; the replacement `child` is omitted.
  override def flatArguments: Iterator[Any] = Iterator(left, right)
  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}


/**
 * `nullif(left, right)`: the auxiliary constructor builds `child` as
 * `If(EqualTo(left, right), <null of left's type>, left)`.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2) - Returns null if `expr1` equals to `expr2`, or `expr1` otherwise.",
  extended = """
    Examples:
      > SELECT _FUNC_(2, 2);
       NULL
  """)
case class NullIf(left: Expression, right: Expression, child: Expression)
  extends RuntimeReplaceable {

  def this(left: Expression, right: Expression) = {
    this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left))
  }

  // Only the user-visible arguments are listed; the replacement `child` is omitted.
  override def flatArguments: Iterator[Any] = Iterator(left, right)
  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}


/**
 * `nvl(left, right)`: same rewrite as [[IfNull]]; the auxiliary constructor builds `child` as
 * `Coalesce(Seq(left, right))`.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise.",
  extended = """
    Examples:
      > SELECT _FUNC_(NULL, array('2'));
       ["2"]
  """)
case class Nvl(left: Expression, right: Expression, child: Expression) extends RuntimeReplaceable {

  def this(left: Expression, right: Expression) = {
    this(left, right, Coalesce(Seq(left, right)))
  }

  // Only the user-visible arguments are listed; the replacement `child` is omitted.
  override def flatArguments: Iterator[Any] = Iterator(left, right)
  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}


/**
 * `nvl2(expr1, expr2, expr3)`: the auxiliary constructor builds `child` as
 * `If(IsNotNull(expr1), expr2, expr3)`.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2, expr3) - Returns `expr2` if `expr1` is not null, or `expr3` otherwise.",
  extended = """
    Examples:
      > SELECT _FUNC_(NULL, 2, 1);
       1
  """)
// scalastyle:on line.size.limit
case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child: Expression)
  extends RuntimeReplaceable {

  def this(expr1: Expression, expr2: Expression, expr3: Expression) = {
    this(expr1, expr2, expr3, If(IsNotNull(expr1), expr2, expr3))
  }

  // Only the user-visible arguments are listed; the replacement `child` is omitted.
  override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3)
  override def sql: String = s"$prettyName(${expr1.sql}, ${expr2.sql}, ${expr3.sql})"
}


/**
 * Evaluates to `true` iff its input is NaN. A null input evaluates to `false`, so this
 * predicate is never null.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns true if `expr` is NaN, or false otherwise.",
  extended = """
    Examples:
      > SELECT _FUNC_(cast('NaN' as double));
       true
  """)
case class IsNaN(child: Expression) extends UnaryExpression
  with Predicate with ImplicitCastInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))

  override def nullable: Boolean = false

  override def eval(input: InternalRow): Any = {
    val value = child.eval(input)
    if (value == null) {
      false
    } else {
      // Only Double/Float reach here: `inputTypes` restricts the child to those types.
      child.dataType match {
        case DoubleType => value.asInstanceOf[Double].isNaN
        case FloatType => value.asInstanceOf[Float].isNaN
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    child.dataType match {
      case DoubleType | FloatType =>
        // A float value is widened to double by Java, which preserves NaN.
        ev.copy(code = s"""
          ${eval.code}
          ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
          ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false")
    }
  }
}

/**
 * An expression that evaluates to `left` if it is not NaN, or to `right` otherwise; a null
 * `left` evaluates to null. This expression is useful for mapping NaN values to null.
 * `right` is only evaluated when `left` is NaN.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2) - Returns `expr1` if it's not NaN, or `expr2` otherwise.",
  extended = """
    Examples:
      > SELECT _FUNC_(cast('NaN' as double), 123);
       123.0
  """)
case class NaNvl(left: Expression, right: Expression)
    extends BinaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = left.dataType

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(DoubleType, FloatType), TypeCollection(DoubleType, FloatType))

  override def eval(input: InternalRow): Any = {
    val value = left.eval(input)
    if (value == null) {
      null
    } else {
      left.dataType match {
        case DoubleType =>
          if (!value.asInstanceOf[Double].isNaN) value else right.eval(input)
        case FloatType =>
          if (!value.asInstanceOf[Float].isNaN) value else right.eval(input)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val leftGen = left.genCode(ctx)
    val rightGen = right.genCode(ctx)
    left.dataType match {
      case DoubleType | FloatType =>
        // `right`'s code is emitted only in the NaN branch, mirroring `eval`.
        ev.copy(code = s"""
          ${leftGen.code}
          boolean ${ev.isNull} = false;
          ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
          if (${leftGen.isNull}) {
            ${ev.isNull} = true;
          } else {
            if (!Double.isNaN(${leftGen.value})) {
              ${ev.value} = ${leftGen.value};
            } else {
              ${rightGen.code}
              if (${rightGen.isNull}) {
                ${ev.isNull} = true;
              } else {
                ${ev.value} = ${rightGen.value};
              }
            }
          }""")
    }
  }
}


/**
 * An expression that is evaluated to true if the input is null.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns true if `expr` is null, or false otherwise.",
  extended = """
    Examples:
      > SELECT _FUNC_(1);
       false
  """)
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
  override def nullable: Boolean = false

  override def eval(input: InternalRow): Any = {
    child.eval(input) == null
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    // The child's null flag is itself the boolean result.
    ExprCode(code = eval.code, isNull = "false", value = eval.isNull)
  }

  override def sql: String = s"(${child.sql} IS NULL)"
}


/**
 * An expression that is evaluated to true if the input is not null.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns true if `expr` is not null, or false otherwise.",
  extended = """
    Examples:
      > SELECT _FUNC_(1);
       true
  """)
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
  override def nullable: Boolean = false

  override def eval(input: InternalRow): Any = {
    child.eval(input) != null
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    // The result is the negation of the child's null flag.
    ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))")
  }

  override def sql: String = s"(${child.sql} IS NOT NULL)"
}


/**
 * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
 * NaN is only checked for Double and Float children; values of other types count whenever
 * they are non-null. Evaluation stops once `n` qualifying values have been seen.
 */
case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
  override def nullable: Boolean = false
  override def foldable: Boolean = children.forall(_.foldable)
  // Fixed: previously printed the literal "n" under the wrong name "AtLeastNNulls".
  override def toString: String = s"AtLeastNNonNulls($n, ${children.mkString(",")})"

  private[this] val childrenArray = children.toArray

  override def eval(input: InternalRow): Boolean = {
    var numNonNulls = 0
    var i = 0
    while (i < childrenArray.length && numNonNulls < n) {
      val evalC = childrenArray(i).eval(input)
      if (evalC != null) {
        childrenArray(i).dataType match {
          case DoubleType =>
            if (!evalC.asInstanceOf[Double].isNaN) numNonNulls += 1
          case FloatType =>
            if (!evalC.asInstanceOf[Float].isNaN) numNonNulls += 1
          case _ => numNonNulls += 1
        }
      }
      i += 1
    }
    numNonNulls >= n
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val nonnull = ctx.freshName("nonnull")
    // Each child's code is guarded by `nonnull < n` so later children are skipped once the
    // threshold has been reached.
    val code = children.map { e =>
      val eval = e.genCode(ctx)
      e.dataType match {
        case DoubleType | FloatType =>
          s"""
            if ($nonnull < $n) {
              ${eval.code}
              if (!${eval.isNull} && !Double.isNaN(${eval.value})) {
                $nonnull += 1;
              }
            }
          """
        case _ =>
          s"""
            if ($nonnull < $n) {
              ${eval.code}
              if (!${eval.isNull}) {
                $nonnull += 1;
              }
            }
          """
      }
    }.mkString("\n")
    ev.copy(code = s"""
      int $nonnull = 0;
      $code
      boolean ${ev.value} = $nonnull >= $n;""", isNull = "false")
  }
}