/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._

/**
 * Conditional expression. Evaluates `trueValue` when `predicate` evaluates to `true`, and
 * `falseValue` in every other case (including a null predicate, because only `Boolean.TRUE`
 * selects the true branch).
 *
 * @param predicate condition; must be of boolean type to pass type checking
 * @param trueValue expression evaluated when the predicate is true
 * @param falseValue expression evaluated otherwise; must have the same type as `trueValue`
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2, expr3) - If `expr1` evaluates to true, then returns `expr2`; otherwise returns `expr3`.",
  extended = """
    Examples:
      > SELECT _FUNC_(1 < 2, 'a', 'b');
       a
  """)
// scalastyle:on line.size.limit
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
  extends Expression {

  override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil

  // The predicate's nullability is irrelevant: a null predicate simply selects the false branch.
  override def nullable: Boolean = trueValue.nullable || falseValue.nullable

  /**
   * Succeeds only when the predicate is boolean and both value expressions have the same type.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (predicate.dataType != BooleanType) {
      TypeCheckResult.TypeCheckFailure(
        s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
    } else if (!trueValue.dataType.sameType(falseValue.dataType)) {
      TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
        s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  // Both branches are checked to have the same type, so the true branch's type is used.
  override def dataType: DataType = trueValue.dataType

  /** Interpreted evaluation: only the selected branch is evaluated. */
  override def eval(input: InternalRow): Any = {
    if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) {
      trueValue.eval(input)
    } else {
      falseValue.eval(input)
    }
  }

  /**
   * Generates Java code for this expression. When the combined code of the three children is
   * large and the children read from an input row, each child is emitted as its own private
   * function (see `createAndAddFunction`); otherwise the children's code is inlined.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val condEval = predicate.genCode(ctx)
    val trueEval = trueValue.genCode(ctx)
    val falseEval = falseValue.genCode(ctx)

    // place generated code of condition, true value and false value in separate methods if
    // their code combined is large
    val splitThreshold = 1024
    val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length
    // Split these expressions only if they are created from a row object
    val readsFromRow = ctx.INPUT_ROW != null && ctx.currentVars == null

    val generatedCode = if (combinedLength > splitThreshold && readsFromRow) {
      val (condFuncName, condGlobalIsNull, condGlobalValue) =
        createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr")
      val (trueFuncName, trueGlobalIsNull, trueGlobalValue) =
        createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr")
      val (falseFuncName, falseGlobalIsNull, falseGlobalValue) =
        createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr")
      s"""
        $condFuncName(${ctx.INPUT_ROW});
        boolean ${ev.isNull} = false;
        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
        if (!$condGlobalIsNull && $condGlobalValue) {
          $trueFuncName(${ctx.INPUT_ROW});
          ${ev.isNull} = $trueGlobalIsNull;
          ${ev.value} = $trueGlobalValue;
        } else {
          $falseFuncName(${ctx.INPUT_ROW});
          ${ev.isNull} = $falseGlobalIsNull;
          ${ev.value} = $falseGlobalValue;
        }
      """
    } else {
      s"""
        ${condEval.code}
        boolean ${ev.isNull} = false;
        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
        if (!${condEval.isNull} && ${condEval.value}) {
          ${trueEval.code}
          ${ev.isNull} = ${trueEval.isNull};
          ${ev.value} = ${trueEval.value};
        } else {
          ${falseEval.code}
          ${ev.isNull} = ${falseEval.isNull};
          ${ev.value} = ${falseEval.value};
        }
      """
    }

    ev.copy(code = generatedCode)
  }

  /**
   * Registers a private function that runs the code in `ev` and stores its null flag and value
   * into two freshly named mutable state variables, which are also registered on `ctx`.
   *
   * @param ctx codegen context that receives the mutable state and the new function
   * @param ev generated code of the expression to wrap
   * @param dataType type of the wrapped expression, used for the value variable
   * @param baseFuncName prefix from which the unique function name is derived
   * @return (function name, name of the isNull variable, name of the value variable)
   */
  private def createAndAddFunction(
      ctx: CodegenContext,
      ev: ExprCode,
      dataType: DataType,
      baseFuncName: String): (String, String, String) = {
    val globalIsNull = ctx.freshName("isNull")
    ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;")
    val globalValue = ctx.freshName("value")
    ctx.addMutableState(ctx.javaType(dataType), globalValue,
      s"$globalValue = ${ctx.defaultValue(dataType)};")
    val funcName = ctx.freshName(baseFuncName)
    val funcBody =
      s"""
         |private void $funcName(InternalRow ${ctx.INPUT_ROW}) {
         |  ${ev.code.trim}
         |  $globalIsNull = ${ev.isNull};
         |  $globalValue = ${ev.value};
         |}
         """.stripMargin
    ctx.addNewFunction(funcName, funcBody)
    (funcName, globalIsNull, globalValue)
  }

  override def toString: String = s"if ($predicate) $trueValue else $falseValue"

  override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
}

/**
 * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
 *
 * @param branches seq of (branch condition, branch value)
 * @param elseValue optional value for the else branch
 */
abstract class CaseWhenBase(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression])
  extends Expression with Serializable {

  // Flattened as cond1, value1, cond2, value2, ..., followed by the else value if present.
  override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

  // both then and else expressions should be considered.
  def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)

  /** True when every adjacent pair of value types is the same type (trivially so for <= 1). */
  def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
    case Seq(dt1, dt2) => dt1.sameType(dt2)
  }

  // Taken from the first branch's value; requires at least one branch (`head` on an empty
  // seq throws).
  override def dataType: DataType = branches.head._2.dataType

  override def nullable: Boolean = {
    // Result is nullable if any of the branch is nullable, or if the else value is nullable.
    // Without an else value the result is null whenever no branch matches.
    branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
  }

  /**
   * Checks that all THEN/ELSE values have the same type, and then that every WHEN condition
   * is boolean. The failure message for a non-boolean condition reports the 1-based position
   * of the first offending branch together with its data type.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (!valueTypesEqual) {
      TypeCheckResult.TypeCheckFailure(
        "THEN and ELSE expressions should all be same type or coercible to a common type")
    } else {
      // Make sure all branch conditions are boolean types.
      val index = branches.indexWhere(_._1.dataType != BooleanType)
      if (index < 0) {
        TypeCheckResult.TypeCheckSuccess
      } else {
        TypeCheckResult.TypeCheckFailure(
          s"WHEN expressions in CaseWhen should all be boolean type, " +
            s"but the ${index + 1}th when expression's type is ${branches(index)._1.dataType}")
      }
    }
  }

  /**
   * Interpreted evaluation: conditions are evaluated in order and the value of the first branch
   * whose condition is `true` is returned. If none matches, the else value is evaluated, or
   * null is returned when there is no else value.
   */
  override def eval(input: InternalRow): Any = {
    var i = 0
    val size = branches.size
    while (i < size) {
      if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) {
        return branches(i)._2.eval(input)
      }
      i += 1
    }
    if (elseValue.isDefined) elseValue.get.eval(input) else null
  }

  override def toString: String = {
    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
    "CASE" + cases + elseCase + " END"
  }

  override def sql: String = {
    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
    "CASE" + cases + elseCase + " END"
  }
}


/**
 * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
 * When a = true, returns b; when c = true, returns d; else returns e.
 *
 * @param branches seq of (branch condition, branch value)
 * @param elseValue optional value for the else branch
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; when `expr3` = true, return `expr4`; else return `expr5`.")
// scalastyle:on line.size.limit
case class CaseWhen(
    val branches: Seq[(Expression, Expression)],
    val elseValue: Option[Expression] = None)
  extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {

  // Explicitly select the CodegenFallback implementation of doGenCode.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    super[CodegenFallback].doGenCode(ctx, ev)
  }

  /** Returns the code-generating counterpart with the same branches and else value. */
  def toCodegen(): CaseWhenCodegen = {
    CaseWhenCodegen(branches, elseValue)
  }
}

/**
 * CaseWhen expression used when code generation condition is satisfied.
 * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
 *
 * @param branches seq of (branch condition, branch value)
 * @param elseValue optional value for the else branch
 */
case class CaseWhenCodegen(
    val branches: Seq[(Expression, Expression)],
    val elseValue: Option[Expression] = None)
  extends CaseWhenBase(branches, elseValue) with Serializable {

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Generate code that looks like:
    //
    // condA = ...
    // if (condA) {
    //   valueA
    // } else {
    //   condB = ...
    //   if (condB) {
    //     valueB
    //   } else {
    //     condC = ...
    //     if (condC) {
    //       valueC
    //     } else {
    //       elseValue
    //     }
    //   }
    // }
    val cases = branches.map { case (condExpr, valueExpr) =>
      val cond = condExpr.genCode(ctx)
      val res = valueExpr.genCode(ctx)
      s"""
        ${cond.code}
        if (!${cond.isNull} && ${cond.value}) {
          ${res.code}
          ${ev.isNull} = ${res.isNull};
          ${ev.value} = ${res.value};
        }
      """
    }

    // Every case is followed by an opening "else {" (as separator or as terminator), so exactly
    // `cases.size` blocks are opened here and closed again below.
    val branchesCode = cases.mkString("", "\nelse {\n", "\nelse {\n")

    // Code for the innermost else block; empty when there is no else value.
    val elseCode = elseValue.map { elseExpr =>
      val res = elseExpr.genCode(ctx)
      s"""
        ${res.code}
        ${ev.isNull} = ${res.isNull};
        ${ev.value} = ${res.value};
      """
    }.getOrElse("")

    val generatedCode = branchesCode + elseCode + "}\n" * cases.size

    // isNull starts out true, so the result is null when no branch assigns a value.
    ev.copy(code = s"""
      boolean ${ev.isNull} = true;
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
      $generatedCode""")
  }
}

/** Factory methods for CaseWhen. */
object CaseWhen {
  /** Creates a CaseWhen whose else value is `elseValue`, or no else value if it is null. */
  def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
    CaseWhen(branches, Option(elseValue))
  }

  /**
   * A factory method to facilitate the creation of this expression when used in parsers.
   *
   * @param branches Expressions at even position are the branch conditions, and expressions at odd
   *                 position are branch values. A trailing unpaired expression is the else value.
   */
  def createFromParser(branches: Seq[Expression]): CaseWhen = {
    // Match with Seq(...) rather than `::` so that any Seq implementation is accepted, not only
    // List (a `::` pattern throws MatchError for e.g. Vector or ArrayBuffer chunks).
    val cases = branches.grouped(2).flatMap {
      case Seq(cond, value) => Some((cond, value))
      case Seq(_) => None
    }.toArray.toSeq  // force materialization to make the seq serializable
    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
    CaseWhen(cases, elseValue)
  }
}

/**
 * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
 * When a = b, returns c; when a = d, returns e; else returns f.
 */
object CaseKeyWhen {
  /**
   * Builds a [[CaseWhen]] in which every branch condition is `EqualTo(key, cond)`.
   *
   * @param key expression compared against each branch condition
   * @param branches Expressions at even position are the values compared with `key`, and
   *                 expressions at odd position are branch values. A trailing unpaired
   *                 expression is the else value.
   */
  def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
    // Match with Seq(...) rather than `::` so that any Seq implementation is accepted, not only
    // List (a `::` pattern throws MatchError for e.g. Vector or ArrayBuffer chunks).
    val cases = branches.grouped(2).flatMap {
      case Seq(cond, value) => Some((EqualTo(key, cond), value))
      case Seq(_) => None
    }.toArray.toSeq  // force materialization to make the seq serializable
    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
    CaseWhen(cases, elseValue)
  }
}