1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.expressions
19
20import org.apache.spark.sql.catalyst.InternalRow
21import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
22import org.apache.spark.sql.catalyst.expressions.codegen._
23import org.apache.spark.sql.types._
24
/**
 * Conditional expression: returns `trueValue` when `predicate` evaluates to true and
 * `falseValue` otherwise. A null predicate is treated like false and selects `falseValue`.
 *
 * @param predicate boolean condition
 * @param trueValue value returned when the predicate is true
 * @param falseValue value returned when the predicate is false or null
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2, expr3) - If `expr1` evaluates to true, then returns `expr2`; otherwise returns `expr3`.",
  extended = """
    Examples:
      > SELECT _FUNC_(1 < 2, 'a', 'b');
       a
  """)
// scalastyle:on line.size.limit
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
  extends Expression {

  override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil

  // Either branch may be returned, so the result is nullable if either branch is.
  override def nullable: Boolean = trueValue.nullable || falseValue.nullable

  override def checkInputDataTypes(): TypeCheckResult = {
    if (predicate.dataType != BooleanType) {
      // Use simpleString here, consistent with the branch-type message below.
      TypeCheckResult.TypeCheckFailure(
        "type of predicate expression in If should be boolean, " +
          s"not ${predicate.dataType.simpleString}")
    } else if (!trueValue.dataType.sameType(falseValue.dataType)) {
      TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
        s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  // checkInputDataTypes requires both branches to have the same type, so either one is
  // representative of the result type.
  override def dataType: DataType = trueValue.dataType

  override def eval(input: InternalRow): Any = {
    // Only an actual boolean `true` selects the true branch; null and false both fall through.
    if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) {
      trueValue.eval(input)
    } else {
      falseValue.eval(input)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val condEval = predicate.genCode(ctx)
    val trueEval = trueValue.genCode(ctx)
    val falseEval = falseValue.genCode(ctx)

    // place generated code of condition, true value and false value in separate methods if
    // their code combined is large (more than 1024 characters of generated source)
    val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length
    val generatedCode = if (combinedLength > 1024 &&
      // Split these expressions only if they are created from a row object
      (ctx.INPUT_ROW != null && ctx.currentVars == null)) {

      val (condFuncName, condGlobalIsNull, condGlobalValue) =
        createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr")
      val (trueFuncName, trueGlobalIsNull, trueGlobalValue) =
        createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr")
      val (falseFuncName, falseGlobalIsNull, falseGlobalValue) =
        createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr")
      s"""
        $condFuncName(${ctx.INPUT_ROW});
        boolean ${ev.isNull} = false;
        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
        if (!$condGlobalIsNull && $condGlobalValue) {
          $trueFuncName(${ctx.INPUT_ROW});
          ${ev.isNull} = $trueGlobalIsNull;
          ${ev.value} = $trueGlobalValue;
        } else {
          $falseFuncName(${ctx.INPUT_ROW});
          ${ev.isNull} = $falseGlobalIsNull;
          ${ev.value} = $falseGlobalValue;
        }
      """
    } else {
      s"""
        ${condEval.code}
        boolean ${ev.isNull} = false;
        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
        if (!${condEval.isNull} && ${condEval.value}) {
          ${trueEval.code}
          ${ev.isNull} = ${trueEval.isNull};
          ${ev.value} = ${trueEval.value};
        } else {
          ${falseEval.code}
          ${ev.isNull} = ${falseEval.isNull};
          ${ev.value} = ${falseEval.value};
        }
      """
    }

    ev.copy(code = generatedCode)
  }

  /**
   * Moves the already-generated code of one sub-expression into its own private method.
   *
   * Two mutable-state fields are registered on `ctx` (an `isNull` flag and a value of
   * `dataType`), and the new method copies `ev`'s result into them so the caller can read the
   * result after invoking the method.
   *
   * @param ctx codegen context to register the fields and method on
   * @param ev generated code of the sub-expression to wrap
   * @param dataType data type of the sub-expression's value
   * @param baseFuncName prefix used to derive a fresh method name
   * @return (method name, isNull field name, value field name)
   */
  private def createAndAddFunction(
      ctx: CodegenContext,
      ev: ExprCode,
      dataType: DataType,
      baseFuncName: String): (String, String, String) = {
    val globalIsNull = ctx.freshName("isNull")
    ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;")
    val globalValue = ctx.freshName("value")
    ctx.addMutableState(ctx.javaType(dataType), globalValue,
      s"$globalValue = ${ctx.defaultValue(dataType)};")
    val funcName = ctx.freshName(baseFuncName)
    val funcBody =
      s"""
         |private void $funcName(InternalRow ${ctx.INPUT_ROW}) {
         |  ${ev.code.trim}
         |  $globalIsNull = ${ev.isNull};
         |  $globalValue = ${ev.value};
         |}
         """.stripMargin
    ctx.addNewFunction(funcName, funcBody)
    (funcName, globalIsNull, globalValue)
  }

  override def toString: String = s"if ($predicate) $trueValue else $falseValue"

  override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
}
142
/**
 * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
 *
 * Branch conditions are tested in order and the value of the first branch whose condition
 * evaluates to true is returned. When no condition matches, the else value is returned, or
 * null if there is no else branch.
 *
 * @param branches seq of (branch condition, branch value)
 * @param elseValue optional value for the else branch
 */
abstract class CaseWhenBase(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression])
  extends Expression with Serializable {

  override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

  /** Data types of every THEN value followed by the ELSE value, if present. */
  def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)

  /** True when every adjacent pair of THEN/ELSE value types satisfies `sameType`. */
  def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
    case Seq(dt1, dt2) => dt1.sameType(dt2)
  }

  // checkInputDataTypes requires all value types to agree, so the first THEN value is
  // representative. NOTE(review): assumes at least one branch exists.
  override def dataType: DataType = branches.head._2.dataType

  override def nullable: Boolean = {
    // Result is nullable if any branch value or the else value is nullable. With no else
    // branch, an unmatched row evaluates to null, so the result is nullable as well.
    branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    // First require all THEN/ELSE values to share a type, then all WHEN conditions to be boolean.
    if (!valueTypesEqual) {
      TypeCheckResult.TypeCheckFailure(
        "THEN and ELSE expressions should all be same type or coercible to a common type")
    } else {
      val index = branches.indexWhere(_._1.dataType != BooleanType)
      if (index < 0) {
        TypeCheckResult.TypeCheckSuccess
      } else {
        // Report the offending condition's data type (not the expression itself).
        TypeCheckResult.TypeCheckFailure(
          s"WHEN expressions in CaseWhen should all be boolean type, " +
            s"but the ${index + 1}th when expression's type is " +
            s"${branches(index)._1.dataType.simpleString}")
      }
    }
  }

  override def eval(input: InternalRow): Any = {
    // Conditions are evaluated in order and evaluation stops at the first one that is exactly
    // true; a null condition counts as not matched. Iterating (rather than indexing) keeps this
    // linear even when `branches` is a List.
    val matched = branches.find { case (cond, _) =>
      java.lang.Boolean.TRUE.equals(cond.eval(input))
    }
    matched match {
      case Some((_, value)) => value.eval(input)
      case None =>
        elseValue match {
          case Some(elseExpr) => elseExpr.eval(input)
          case None => null
        }
    }
  }

  override def toString: String = {
    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
    "CASE" + cases + elseCase + " END"
  }

  override def sql: String = {
    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
    "CASE" + cases + elseCase + " END"
  }
}
215
216
/**
 * Searched CASE expression of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
 * Yields b when a is true, otherwise d when c is true, otherwise e (null when there is no ELSE).
 *
 * Code generation for this class is delegated to [[CodegenFallback]]; use [[toCodegen]] to
 * obtain the [[CaseWhenCodegen]] variant that emits dedicated Java code.
 *
 * @param branches seq of (branch condition, branch value)
 * @param elseValue optional value for the else branch
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; when `expr3` = true, return `expr4`; else return `expr5`.")
// scalastyle:on line.size.limit
case class CaseWhen(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression] = None)
  extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {

  /** Explicitly routes code generation to the [[CodegenFallback]] implementation. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    super[CodegenFallback].doGenCode(ctx, ev)

  /** Returns a [[CaseWhenCodegen]] built from this expression's branches and else value. */
  def toCodegen(): CaseWhenCodegen = CaseWhenCodegen(branches, elseValue)
}
241
/**
 * Variant of [[CaseWhen]] that generates Java code for its branches instead of falling back to
 * interpreted evaluation. The OptimizeCodegen optimizer rule substitutes it for a CaseWhen when
 * the code generation condition is satisfied.
 *
 * @param branches seq of (branch condition, branch value)
 * @param elseValue optional value for the else branch
 */
case class CaseWhenCodegen(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression] = None)
  extends CaseWhenBase(branches, elseValue) with Serializable {

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Each branch becomes an `if`; the next branch is nested inside its `else`:
    //
    //   condA = ...
    //   if (condA) { valueA } else {
    //     condB = ...
    //     if (condB) { valueB } else {
    //       elseValue
    //     }
    //   }
    //
    // The result starts out as null and is overwritten only by the branch that is taken.
    val branchBlocks = branches.map { case (condExpr, valueExpr) =>
      val condCode = condExpr.genCode(ctx)
      val valueCode = valueExpr.genCode(ctx)
      s"""
        ${condCode.code}
        if (!${condCode.isNull} && ${condCode.value}) {
          ${valueCode.code}
          ${ev.isNull} = ${valueCode.isNull};
          ${ev.value} = ${valueCode.value};
        }
      """
    }

    // Join branches with "else {" and open one final "else {" for the ELSE clause.
    val chained = branchBlocks.mkString("", "\nelse {\n", "\nelse {\n")

    val elseBlock = elseValue match {
      case Some(elseExpr) =>
        val elseCode = elseExpr.genCode(ctx)
        s"""
          ${elseCode.code}
          ${ev.isNull} = ${elseCode.isNull};
          ${ev.value} = ${elseCode.value};
        """
      case None => ""
    }

    // One closing brace for each "else {" opened above.
    val closingBraces = "}\n" * branchBlocks.size

    ev.copy(code = s"""
      boolean ${ev.isNull} = true;
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
      $chained$elseBlock$closingBraces""")
  }
}
306
/** Factory methods for CaseWhen. */
object CaseWhen {
  /** Creates a [[CaseWhen]] from a plain else expression; a null `elseValue` means no ELSE. */
  def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
    CaseWhen(branches, Option(elseValue))
  }

  /**
   * A factory method to facilitate the creation of this expression when used in parsers.
   *
   * @param branches Expressions at even position are the branch conditions, and expressions at odd
   *                 position are branch values. A trailing unpaired expression is the else value.
   */
  def createFromParser(branches: Seq[Expression]): CaseWhen = {
    // Match groups with the `Seq` extractor rather than `::`: `grouped` builds each group with
    // the source collection's builder, so a non-List input (ArrayBuffer, WrappedArray, Vector)
    // yields non-List groups that `::` patterns would reject with a MatchError.
    val cases = branches.grouped(2).flatMap {
      case Seq(cond, value) => Some((cond, value))
      case Seq(_) => None
    }.toArray.toSeq  // force materialization to make the seq serializable
    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
    CaseWhen(cases, elseValue)
  }
}
328
/**
 * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
 * When a = b, returns c; when a = d, returns e; else returns f.
 *
 * Builds an equivalent [[CaseWhen]] whose conditions are `EqualTo(a, b)`, `EqualTo(a, d)`, ...
 */
object CaseKeyWhen {
  /**
   * @param key expression compared against each WHEN value
   * @param branches alternating WHEN values and THEN results; a trailing unpaired expression is
   *                 the ELSE value
   */
  def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
    // Use the `Seq` extractor (not `::`) so non-List inputs do not fail with a MatchError.
    val cases = branches.grouped(2).flatMap {
      case Seq(cond, value) => Some((EqualTo(key, cond), value))
      case Seq(_) => None
    }.toArray.toSeq  // force materialization to make the seq serializable
    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
    CaseWhen(cases, elseValue)
  }
}
343