1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.expressions
19
20import org.apache.spark.sql.catalyst.InternalRow
21import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
22import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
23import org.apache.spark.sql.catalyst.util.TypeUtils
24import org.apache.spark.sql.types._
25
26
27/**
28 * An expression that is evaluated to the first non-null input.
29 *
30 * {{{
31 *   coalesce(1, 2) => 1
32 *   coalesce(null, 1, 2) => 1
33 *   coalesce(null, null, 2) => 2
34 *   coalesce(null, null, null) => null
35 * }}}
36 */
37// scalastyle:off line.size.limit
38@ExpressionDescription(
39  usage = "_FUNC_(expr1, expr2, ...) - Returns the first non-null argument if exists. Otherwise, null.",
40  extended = """
41    Examples:
42      > SELECT _FUNC_(NULL, 1, NULL);
43       1
44  """)
45// scalastyle:on line.size.limit
case class Coalesce(children: Seq[Expression]) extends Expression {

  /**
   * The result can be null only when no child is guaranteed to be non-null. With an empty child
   * list this is vacuously true.
   */
  override def nullable: Boolean = !children.exists(!_.nullable)

  /** Foldable exactly when every child is foldable. */
  override def foldable: Boolean = children.forall(child => child.foldable)

  /**
   * Rejects an empty argument list; otherwise delegates to
   * `TypeUtils.checkForSameTypeInputExpr` with the data types of all children.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.nonEmpty) {
      val childTypes = children.map(_.dataType)
      TypeUtils.checkForSameTypeInputExpr(childTypes, "function coalesce")
    } else {
      TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
    }
  }

  /** The result type is taken from the first child. */
  override def dataType: DataType = children.head.dataType

  /**
   * Evaluates the children in order and returns the first non-null value, or null when every
   * child evaluates to null. Children after the first non-null one are not evaluated, because
   * the iterator pipeline is lazy and `find` stops at the first match.
   */
  override def eval(input: InternalRow): Any =
    children.iterator.map(_.eval(input)).find(_ != null).getOrElse(null)

  /**
   * Emits code that initialises the result from the first child and then, for each remaining
   * child, overwrites the result only while it is still null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Generate code for the first child before the remaining ones so that the order of
    // genCode calls on ctx is first child, then the rest left to right.
    val headGen = children(0).genCode(ctx)
    val prelude = s"""
      ${headGen.code}
      boolean ${ev.isNull} = ${headGen.isNull};
      ${ctx.javaType(dataType)} ${ev.value} = ${headGen.value};"""
    val fallbacks = children.drop(1).map { candidate =>
      val candidateGen = candidate.genCode(ctx)
      s"""
        if (${ev.isNull}) {
          ${candidateGen.code}
          if (!${candidateGen.isNull}) {
            ${ev.isNull} = false;
            ${ev.value} = ${candidateGen.value};
          }
        }
      """
    }
    ev.copy(code = prelude + fallbacks.mkString("\n"))
  }
}
95
96
97@ExpressionDescription(
98  usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise.",
99  extended = """
100    Examples:
101      > SELECT _FUNC_(NULL, array('2'));
102       ["2"]
103  """)
case class IfNull(left: Expression, right: Expression, child: Expression)
  extends RuntimeReplaceable {

  /** Two-argument form: `child` is built as a `Coalesce` over `left` followed by `right`. */
  def this(left: Expression, right: Expression) =
    this(left, right, Coalesce(left :: right :: Nil))

  // Exposes only the two original arguments; the derived `child` is not included.
  override def flatArguments: Iterator[Any] = Seq(left, right).iterator

  // Rendered from the two original arguments, not from `child`.
  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
114
115
116@ExpressionDescription(
117  usage = "_FUNC_(expr1, expr2) - Returns null if `expr1` equals to `expr2`, or `expr1` otherwise.",
118  extended = """
119   Examples:
120     > SELECT _FUNC_(2, 2);
121      NULL
122  """)
case class NullIf(left: Expression, right: Expression, child: Expression)
  extends RuntimeReplaceable {

  /**
   * Two-argument form: `child` is built as an `If` over `EqualTo(left, right)` that selects a
   * null literal of `left`'s data type in its first branch and `left` itself in its second.
   */
  def this(left: Expression, right: Expression) =
    this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left))

  // Exposes only the two original arguments; the derived `child` is not included.
  override def flatArguments: Iterator[Any] = Seq(left, right).iterator

  // Rendered from the two original arguments, not from `child`.
  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
133
134
135@ExpressionDescription(
136  usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise.",
137  extended = """
138    Examples:
139      > SELECT _FUNC_(NULL, array('2'));
140       ["2"]
141  """)
case class Nvl(left: Expression, right: Expression, child: Expression)
  extends RuntimeReplaceable {

  /** Two-argument form: `child` is built as a `Coalesce` over `left` followed by `right`. */
  def this(left: Expression, right: Expression) =
    this(left, right, Coalesce(left :: right :: Nil))

  // Exposes only the two original arguments; the derived `child` is not included.
  override def flatArguments: Iterator[Any] = Seq(left, right).iterator

  // Rendered from the two original arguments, not from `child`.
  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
151
152
153// scalastyle:off line.size.limit
154@ExpressionDescription(
155  usage = "_FUNC_(expr1, expr2, expr3) - Returns `expr2` if `expr1` is not null, or `expr3` otherwise.",
156  extended = """
157    Examples:
158      > SELECT _FUNC_(NULL, 2, 1);
159       1
160  """)
161// scalastyle:on line.size.limit
case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child: Expression)
  extends RuntimeReplaceable {

  /**
   * Three-argument form: `child` is built as `If(IsNotNull(expr1), expr2, expr3)`.
   */
  def this(expr1: Expression, expr2: Expression, expr3: Expression) =
    this(expr1, expr2, expr3, If(IsNotNull(expr1), expr2, expr3))

  // Exposes only the three original arguments; the derived `child` is not included.
  override def flatArguments: Iterator[Any] = Seq(expr1, expr2, expr3).iterator

  // Rendered from the three original arguments, not from `child`.
  override def sql: String = s"$prettyName(${expr1.sql}, ${expr2.sql}, ${expr3.sql})"
}
172
173
/**
 * Evaluates to `true` iff the input is NaN.
 */
177@ExpressionDescription(
178  usage = "_FUNC_(expr) - Returns true if `expr` is NaN, or false otherwise.",
179  extended = """
180    Examples:
181      > SELECT _FUNC_(cast('NaN' as double));
182       true
183  """)
case class IsNaN(child: Expression) extends UnaryExpression
  with Predicate with ImplicitCastInputTypes {

  // The single input must be a double or a float.
  override def inputTypes: Seq[AbstractDataType] = {
    val floatingPoint = TypeCollection(DoubleType, FloatType)
    Seq(floatingPoint)
  }

  // A null input yields false (see eval), so the result itself is never null.
  override def nullable: Boolean = false

  /**
   * Returns false for a null input; otherwise returns whether the double or float value is NaN.
   */
  override def eval(input: InternalRow): Any = {
    val childValue = child.eval(input)
    childValue != null && (child.dataType match {
      case DoubleType => childValue.asInstanceOf[Double].isNaN
      case FloatType => childValue.asInstanceOf[Float].isNaN
    })
  }

  /**
   * Emits code that sets the result to "child is not null and child value is NaN"; the
   * generated null flag is the constant "false".
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // The child's code is generated before the type dispatch, as in the interpreted path the
    // child is evaluated before its data type is inspected.
    val childGen = child.genCode(ctx)
    child.dataType match {
      case DoubleType | FloatType =>
        val javaCode = s"""
          ${childGen.code}
          ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
          ${ev.value} = !${childGen.isNull} && Double.isNaN(${childGen.value});"""
        ev.copy(code = javaCode, isNull = "false")
    }
  }
}
214
/**
 * An expression that evaluates to `left` if it is not NaN, or to `right` otherwise.
 * This expression is useful for mapping NaN values to null.
 */
219@ExpressionDescription(
220  usage = "_FUNC_(expr1, expr2) - Returns `expr1` if it's not NaN, or `expr2` otherwise.",
221  extended = """
222    Examples:
223      > SELECT _FUNC_(cast('NaN' as double), 123);
224       123.0
225  """)
case class NaNvl(left: Expression, right: Expression)
    extends BinaryExpression with ImplicitCastInputTypes {

  /** The result type is taken from the left operand. */
  override def dataType: DataType = left.dataType

  // Both operands must each be a double or a float.
  override def inputTypes: Seq[AbstractDataType] = {
    val floatingPoint = TypeCollection(DoubleType, FloatType)
    Seq(floatingPoint, floatingPoint)
  }

  /**
   * Returns null when `left` evaluates to null, the left value when it is not NaN, and the
   * result of evaluating `right` otherwise. `right` is evaluated only in the NaN case.
   */
  override def eval(input: InternalRow): Any = left.eval(input) match {
    case null => null
    case leftValue =>
      val leftIsNaN = left.dataType match {
        case DoubleType => leftValue.asInstanceOf[Double].isNaN
        case FloatType => leftValue.asInstanceOf[Float].isNaN
      }
      if (leftIsNaN) right.eval(input) else leftValue
  }

  /**
   * Emits code mirroring `eval`: the result is null if the left side is null, the left value if
   * it is not NaN, and otherwise the right side (null flag included). The right side's code is
   * placed inside the NaN branch of the generated code.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Both operands are generated up front, left first, before the type dispatch.
    val lhs = left.genCode(ctx)
    val rhs = right.genCode(ctx)
    left.dataType match {
      case DoubleType | FloatType =>
        val javaCode = s"""
          ${lhs.code}
          boolean ${ev.isNull} = false;
          ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
          if (${lhs.isNull}) {
            ${ev.isNull} = true;
          } else {
            if (!Double.isNaN(${lhs.value})) {
              ${ev.value} = ${lhs.value};
            } else {
              ${rhs.code}
              if (${rhs.isNull}) {
                ${ev.isNull} = true;
              } else {
                ${ev.value} = ${rhs.value};
              }
            }
          }"""
        ev.copy(code = javaCode)
    }
  }
}
274
275
276/**
277 * An expression that is evaluated to true if the input is null.
278 */
279@ExpressionDescription(
280  usage = "_FUNC_(expr) - Returns true if `expr` is null, or false otherwise.",
281  extended = """
282    Examples:
283      > SELECT _FUNC_(1);
284       false
285  """)
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
  // The check always produces a definite true or false.
  override def nullable: Boolean = false

  /** Returns true when the child evaluates to null, false otherwise. */
  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => true
    case _ => false
  }

  /**
   * Reuses the child's generated code and exposes the child's null flag as this expression's
   * value; the null flag of the result is the constant "false".
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)
    ExprCode(code = childGen.code, isNull = "false", value = childGen.isNull)
  }

  override def sql: String = s"(${child.sql} IS NULL)"
}
300
301
302/**
303 * An expression that is evaluated to true if the input is not null.
304 */
305@ExpressionDescription(
306  usage = "_FUNC_(expr) - Returns true if `expr` is not null, or false otherwise.",
307  extended = """
308    Examples:
309      > SELECT _FUNC_(1);
310       true
311  """)
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
  // The check always produces a definite true or false.
  override def nullable: Boolean = false

  /** Returns false when the child evaluates to null, true otherwise. */
  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => false
    case _ => true
  }

  /**
   * Reuses the child's generated code and exposes the negation of the child's null flag as this
   * expression's value; the null flag of the result is the constant "false".
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)
    val negatedNullFlag = s"(!(${childGen.isNull}))"
    ExprCode(code = childGen.code, isNull = "false", value = negatedNullFlag)
  }

  override def sql: String = s"(${child.sql} IS NOT NULL)"
}
326
327
328/**
329 * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
330 */
case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
  // The predicate always produces a definite true or false.
  override def nullable: Boolean = false
  override def foldable: Boolean = children.forall(_.foldable)

  // Interpolates the actual threshold. The previous text was missing the `$` and therefore
  // always printed the literal letter "n".
  // NOTE(review): the printed name is "AtLeastNNulls" although the class is AtLeastNNonNulls;
  // it is kept unchanged here to avoid altering plan strings further - confirm whether it
  // should be renamed.
  override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})"

  // Array copy of the children, used for index-based access in eval.
  private[this] val childrenArray = children.toArray

  /**
   * Counts children whose value is neither null nor (for double/float children) NaN, and
   * returns whether the count reaches `n`. Evaluation stops as soon as `n` such values have
   * been seen, so later children are not evaluated.
   */
  override def eval(input: InternalRow): Boolean = {
    var numNonNulls = 0
    var i = 0
    while (i < childrenArray.length && numNonNulls < n) {
      val evalC = childrenArray(i).eval(input)
      if (evalC != null) {
        childrenArray(i).dataType match {
          case DoubleType =>
            if (!evalC.asInstanceOf[Double].isNaN) numNonNulls += 1
          case FloatType =>
            if (!evalC.asInstanceOf[Float].isNaN) numNonNulls += 1
          case _ => numNonNulls += 1
        }
      }
      i += 1
    }
    numNonNulls >= n
  }

  /**
   * Emits one guarded block per child. Each block runs the child's code only while the counter
   * is still below `n`, and increments the counter when the child is non-null (and, for
   * double/float children, not NaN). The result is `counter >= n` with a constant "false"
   * null flag.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val nonnull = ctx.freshName("nonnull")
    val code = children.map { e =>
      val eval = e.genCode(ctx)
      e.dataType match {
        case DoubleType | FloatType =>
          s"""
            if ($nonnull < $n) {
              ${eval.code}
              if (!${eval.isNull} && !Double.isNaN(${eval.value})) {
                $nonnull += 1;
              }
            }
          """
        case _ =>
          s"""
            if ($nonnull < $n) {
              ${eval.code}
              if (!${eval.isNull}) {
                $nonnull += 1;
              }
            }
          """
      }
    }.mkString("\n")
    ev.copy(code = s"""
      int $nonnull = 0;
      $code
      boolean ${ev.value} = $nonnull >= $n;""", isNull = "false")
  }
}
388