/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.analysis

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.types._

/**
 * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
 * Used for testing when all relations are already filled in and the analyzer needs only
 * to resolve attribute references.
 */
object SimpleAnalyzer extends Analyzer(
    new SessionCatalog(
      new InMemoryCatalog,
      EmptyFunctionRegistry,
      new SimpleCatalystConf(caseSensitiveAnalysis = true)) {
      override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {}
    },
    new SimpleCatalystConf(caseSensitiveAnalysis = true))

/**
 * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
 * [[UnresolvedRelation]]s into fully typed objects using information in a
 * [[SessionCatalog]] and a [[FunctionRegistry]].
 */
class Analyzer(
    catalog: SessionCatalog,
    conf: CatalystConf,
    maxIterations: Int)
  extends RuleExecutor[LogicalPlan] with CheckAnalysis {

  def this(catalog: SessionCatalog, conf: CatalystConf) = {
    this(catalog, conf, conf.optimizerMaxIterations)
  }

  def resolver: Resolver = conf.resolver

  protected val fixedPoint = FixedPoint(maxIterations)

  /**
   * Override to provide additional rules for the "Resolution" batch.
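   *
   * A minimal, hypothetical sketch of how a subclass might extend this batch (the rule name
   * below is illustrative, not part of Spark):
   * {{{
   *   override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = MyDataSourceResolution :: Nil
   * }}}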
   */
  val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil

  lazy val batches: Seq[Batch] = Seq(
    Batch("Substitution", fixedPoint,
      CTESubstitution,
      WindowsSubstitution,
      EliminateUnions,
      new SubstituteUnresolvedOrdinals(conf)),
    Batch("Resolution", fixedPoint,
      ResolveTableValuedFunctions ::
      ResolveRelations ::
      ResolveReferences ::
      ResolveCreateNamedStruct ::
      ResolveDeserializer ::
      ResolveNewInstance ::
      ResolveUpCast ::
      ResolveGroupingAnalytics ::
      ResolvePivot ::
      ResolveOrdinalInOrderByAndGroupBy ::
      ResolveMissingReferences ::
      ExtractGenerator ::
      ResolveGenerate ::
      ResolveFunctions ::
      ResolveAliases ::
      ResolveSubquery ::
      ResolveWindowOrder ::
      ResolveWindowFrame ::
      ResolveNaturalAndUsingJoin ::
      ExtractWindowExpressions ::
      GlobalAggregates ::
      ResolveAggregateFunctions ::
      TimeWindowing ::
      ResolveInlineTables ::
      TypeCoercion.typeCoercionRules ++
      extendedResolutionRules : _*),
    Batch("Nondeterministic", Once,
      PullOutNondeterministic),
    Batch("UDF", Once,
      HandleNullInputsForUDF),
    Batch("FixNullability", Once,
      FixNullability),
    Batch("Cleanup", fixedPoint,
      CleanupAliases)
  )

  /**
   * Analyze cte definitions and substitute child plan with analyzed cte definitions.
   */
  object CTESubstitution extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case With(child, relations) =>
        substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
          case (resolved, (name, relation)) =>
            resolved :+ name -> execute(substituteCTE(relation, resolved))
        })
      case other => other
    }

    def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = {
      plan transformDown {
        case u : UnresolvedRelation =>
          val substituted = cteRelations.find(x => resolver(x._1, u.tableIdentifier.table))
            .map(_._2).map { relation =>
              val withAlias = u.alias.map(SubqueryAlias(_, relation, None))
              withAlias.getOrElse(relation)
            }
          substituted.getOrElse(u)
        case other =>
          // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
          other transformExpressions {
            case e: SubqueryExpression =>
              e.withNewPlan(substituteCTE(e.plan, cteRelations))
          }
      }
    }
  }

  /**
   * Substitute child plan with WindowSpecDefinitions.
   */
  object WindowsSubstitution extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      // Lookup WindowSpecDefinitions. This rule works with unresolved children.
      case WithWindowDefinition(windowDefinitions, child) =>
        child.transform {
          case p => p.transformExpressions {
            case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
              val errorMessage =
                s"Window specification $windowName is not defined in the WINDOW clause."
              val windowSpecDefinition =
                windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage))
              WindowExpression(c, windowSpecDefinition)
          }
        }
    }
  }

  /**
   * Replaces [[UnresolvedAlias]]s with concrete aliases.
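   *
   * For example (an illustrative sketch, not an exhaustive description of the naming rules):
   * in `SELECT a + 1 FROM t`, the unnamed expression `a + 1` receives an alias derived from its
   * pretty-printed SQL form, roughly:
   * {{{
   *   Alias(Add(a, 1), "(a + 1)")()
   * }}}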
   */
  object ResolveAliases extends Rule[LogicalPlan] {
    private def assignAliases(exprs: Seq[NamedExpression]) = {
      exprs.zipWithIndex.map {
        case (expr, i) =>
          expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) =>
            child match {
              case ne: NamedExpression => ne
              case e if !e.resolved => u
              case g: Generator => MultiAlias(g, Nil)
              case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
              case e: ExtractValue => Alias(e, toPrettySQL(e))()
              case e if optGenAliasFunc.isDefined =>
                Alias(child, optGenAliasFunc.get.apply(e))()
              case e => Alias(e, toPrettySQL(e))()
            }
          }
      }.asInstanceOf[Seq[NamedExpression]]
    }

    private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
      exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
        Aggregate(groups, assignAliases(aggs), child)

      case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
        g.copy(aggregations = assignAliases(g.aggregations))

      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
        if child.resolved && hasUnresolvedAlias(groupByExprs) =>
        Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)

      case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
        Project(assignAliases(projectList), child)
    }
  }

  object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
    /*
     * GROUP BY a, b, c WITH ROLLUP
     * is equivalent to
     * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( ) ).
     * Group Count: N + 1 (N is the number of group expressions)
     *
     * We need to get all of its subsets for the rule described above; each subset is
     * represented as a bit mask.
     */
    def bitmasks(r: Rollup): Seq[Int] = {
      Seq.tabulate(r.groupByExprs.length + 1)(idx => (1 << idx) - 1)
    }

    /*
     * GROUP BY a, b, c WITH CUBE
     * is equivalent to
     * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ).
     * Group Count: 2 ^ N (N is the number of group expressions)
     *
     * We need to get all of the subsets for a given GROUP BY expression; the subsets are
     * represented as bit masks.
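     *
     * For example (illustrative), with 3 group-by expressions this yields the masks 0 through 7
     * (binary 000 ... 111): a set bit marks an expression that is nulled out (absent) in the
     * corresponding grouping set, with the rightmost bit mapping to the last expression, which is
     * how the masks are interpreted further down in this rule.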
     */
    def bitmasks(c: Cube): Seq[Int] = {
      Seq.tabulate(1 << c.groupByExprs.length)(i => i)
    }

    private def hasGroupingAttribute(expr: Expression): Boolean = {
      expr.collectFirst {
        case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u
      }.isDefined
    }

    private[analysis] def hasGroupingFunction(e: Expression): Boolean = {
      e.collectFirst {
        case g: Grouping => g
        case g: GroupingID => g
      }.isDefined
    }

    private def replaceGroupingFunc(
        expr: Expression,
        groupByExprs: Seq[Expression],
        gid: Expression): Expression = {
      expr transform {
        case e: GroupingID =>
          if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
            gid
          } else {
            throw new AnalysisException(
              s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
                s"grouping columns (${groupByExprs.mkString(",")})")
          }
        case Grouping(col: Expression) =>
          val idx = groupByExprs.indexOf(col)
          if (idx >= 0) {
            Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
              Literal(1)), ByteType)
          } else {
            throw new AnalysisException(s"Column of grouping ($col) can't be found " +
              s"in grouping columns ${groupByExprs.mkString(",")}")
          }
      }
    }

    // This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort.
    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
      case a if !a.childrenResolved => a // be sure all of the children are resolved.
      case p if p.expressions.exists(hasGroupingAttribute) =>
        failAnalysis(
          s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")

      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
        GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
        GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)

      // Ensure all the expressions have been resolved.
      case x: GroupingSets if x.expressions.forall(_.resolved) =>
        val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()

        // Expand works by setting grouping expressions to null as determined by the bitmasks. To
        // prevent these null values from being used in an aggregate instead of the original value,
        // we need to create new aliases for all group by expressions that will only be used for
        // the intended purpose.
        val groupByAliases: Seq[Alias] = x.groupByExprs.map {
          case e: NamedExpression => Alias(e, e.name)()
          case other => Alias(other, other.toString)()
        }

        // The rightmost bit in the bitmasks corresponds to the last expression in groupByAliases,
        // with 0 indicating that this expression is in the grouping set. The following line
        // calculates the bitmask representing the expressions that are absent in at least one
        // grouping set (indicated by 1).
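        // For example (illustrative): ROLLUP(a, b) produces the masks 0, 1 and 3, so the
        // nullBitmask computed below is binary 11 and both grouping attributes are marked
        // as nullable.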
        val nullBitmask = x.bitmasks.reduce(_ | _)

        val attrLength = groupByAliases.length
        val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
          val canBeNull = ((nullBitmask >> (attrLength - idx - 1)) & 1) == 1
          a.toAttribute.withNullability(a.nullable || canBeNull)
        }

        val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
        val groupingAttrs = expand.output.drop(x.child.output.length)

        val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
          // Collect all the found AggregateExpressions, so we can check whether an expression
          // is part of any AggregateExpression or not.
          val aggsBuffer = ArrayBuffer[Expression]()
          // Returns whether the expression belongs to any expression in `aggsBuffer` or not.
          def isPartOfAggregation(e: Expression): Boolean = {
            aggsBuffer.exists(a => a.find(_ eq e).isDefined)
          }
          replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
            // AggregateExpression should be computed on the unmodified value of its argument
            // expressions, so we should not replace any references to grouping expressions
            // inside it.
            case e: AggregateExpression =>
              aggsBuffer += e
              e
            case e if isPartOfAggregation(e) => e
            case e =>
              val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
              if (index == -1) {
                e
              } else {
                groupingAttrs(index)
              }
          }.asInstanceOf[NamedExpression]
        }

        Aggregate(groupingAttrs, aggregations, expand)

      case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
        val groupingExprs = findGroupingExprs(child)
        // The unresolved grouping id will be resolved by ResolveMissingReferences.
        val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
        f.copy(condition = newCond)

      case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
        val groupingExprs = findGroupingExprs(child)
        val gid = VirtualColumn.groupingIdAttribute
        // The unresolved grouping id will be resolved by ResolveMissingReferences.
        val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
        s.copy(order = newOrder)
    }

    private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
      plan.collectFirst {
        case a: Aggregate =>
          // This Aggregate should have the grouping id as the last grouping key.
          val gid = a.groupingExpressions.last
          if (!gid.isInstanceOf[AttributeReference]
            || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
            failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
          }
          a.groupingExpressions.take(a.groupingExpressions.length - 1)
      }.getOrElse {
        failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
      }
    }
  }

  object ResolvePivot extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
        | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
        val singleAgg = aggregates.size == 1
        def outputName(value: Literal, aggregate: Expression): String = {
          if (singleAgg) {
            value.toString
          } else {
            val suffix = aggregate match {
              case n: NamedExpression => n.name
              case _ => toPrettySQL(aggregate)
            }
            value + "_" + suffix
          }
        }
        if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
          // Since evaluating |pivotValues| if statements for each input row can get slow, this is
          // an alternate plan that instead uses two steps of aggregation.
          val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
          val namedPivotCol = pivotColumn match {
            case n: NamedExpression => n
            case _ => Alias(pivotColumn, "__pivot_col")()
          }
          val bigGroup = groupByExprs :+ namedPivotCol
          val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
          val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
          val pivotAggs = namedAggExps.map { a =>
            Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
              .toAggregateExpression()
            , "__pivot_" + a.sql)()
          }
          val groupByExprsAttr = groupByExprs.map(_.toAttribute)
          val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg)
          val pivotAggAttribute = pivotAggs.map(_.toAttribute)
          val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
            aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
              Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
            }
          }
          Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
        } else {
          val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
            def ifExpr(expr: Expression) = {
              If(EqualTo(pivotColumn, value), expr, Literal(null))
            }
            aggregates.map { aggregate =>
              val filteredAggregate = aggregate.transformDown {
                // Assumption is the aggregate function ignores nulls. This is true for all
                // current AggregateFunctions, with the exception of First and Last in their
                // default mode (which we handle) and possibly some Hive UDAFs.
                case First(expr, _) =>
                  First(ifExpr(expr), Literal(true))
                case Last(expr, _) =>
                  Last(ifExpr(expr), Literal(true))
                case a: AggregateFunction =>
                  a.withNewChildren(a.children.map(ifExpr))
              }.transform {
                // We are duplicating aggregates that are now computing a different value for each
                // pivot value.
                // TODO: Don't construct the physical container until after analysis.
                case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
              }
              if (filteredAggregate.fastEquals(aggregate)) {
                throw new AnalysisException(
                  s"Aggregate expression required for pivot, found '$aggregate'")
              }
              Alias(filteredAggregate, outputName(value, aggregate))()
            }
          }
          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
        }
    }
  }

  /**
   * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
   */
  object ResolveRelations extends Rule[LogicalPlan] {
    private def lookupTableFromCatalog(u: UnresolvedRelation): LogicalPlan = {
      try {
        catalog.lookupRelation(u.tableIdentifier, u.alias)
      } catch {
        case _: NoSuchTableException =>
          u.failAnalysis(s"Table or view not found: ${u.tableName}")
      }
    }

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
        i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
      case u: UnresolvedRelation =>
        val table = u.tableIdentifier
        if (table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
            (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))) {
          // If the database part is specified, and we support running SQL directly on files, and
          // it's not a temporary view, and the table does not exist, then let's just return the
          // original UnresolvedRelation. It is possible we are matching a query like "select *
          // from parquet.`/path/to/query`". The plan will get resolved later.
          // Note that we are testing (!db_exists || !table_exists) because the catalog throws
          // an exception from tableExists if the database does not exist.
          u
        } else {
          lookupTableFromCatalog(u)
        }
    }
  }

  /**
   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
   * a logical plan node's children.
   */
  object ResolveReferences extends Rule[LogicalPlan] {
    /**
     * Generate a new logical plan for the right child with different expression IDs
     * for all conflicting attributes.
     */
    private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
      val conflictingAttributes = left.outputSet.intersect(right.outputSet)
      logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
        s"between $left and $right")

      right.collect {
        // Handle base relations that might appear more than once.
        case oldVersion: MultiInstanceRelation
            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
          val newVersion = oldVersion.newInstance()
          (oldVersion, newVersion)

        case oldVersion: SerializeFromObject
            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
          (oldVersion, oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance())))

        // Handle projects that create conflicting aliases.
        case oldVersion @ Project(projectList, _)
            if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
          (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))

        case oldVersion @ Aggregate(_, aggregateExpressions, _)
            if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
          (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

        case oldVersion: Generate
            if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
          val newOutput = oldVersion.generatorOutput.map(_.newInstance())
          (oldVersion, oldVersion.copy(generatorOutput = newOutput))

        case oldVersion @ Window(windowExpressions, _, _, child)
            if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
              .nonEmpty =>
          (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
      }
      // Only handle first case, others will be fixed on the next pass.
      .headOption match {
        case None =>
          /*
           * No result implies that there is a logical plan node that produces new references
           * that this rule cannot handle. When that is the case, there must be another rule
           * that resolves these conflicts. Otherwise, the analysis will fail.
           */
          right
        case Some((oldRelation, newRelation)) =>
          val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
          val newRight = right transformUp {
            case r if r == oldRelation => newRelation
          } transformUp {
            case other => other transformExpressions {
              case a: Attribute =>
                attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
            }
          }
          newRight
      }
    }

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p: LogicalPlan if !p.childrenResolved => p

      // If the projection list contains Stars, expand it.
      case p: Project if containsStar(p.projectList) =>
        p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
      // If the aggregate function argument contains Stars, expand it.
      case a: Aggregate if containsStar(a.aggregateExpressions) =>
        if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
          failAnalysis(
            "Star (*) is not allowed in select list when GROUP BY ordinal position is used")
        } else {
          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
        }
      // If the script transformation input contains Stars, expand it.
      case t: ScriptTransformation if containsStar(t.input) =>
        t.copy(
          input = t.input.flatMap {
            case s: Star => s.expand(t.child, resolver)
            case o => o :: Nil
          }
        )
      case g: Generate if containsStar(g.generator.children) =>
        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")

      // To resolve duplicate expression IDs for Join and Intersect
      case j @ Join(left, right, _, _) if !j.duplicateResolved =>
        j.copy(right = dedupRight(left, right))
      case i @ Intersect(left, right) if !i.duplicateResolved =>
        i.copy(right = dedupRight(left, right))
      case i @ Except(left, right) if !i.duplicateResolved =>
        i.copy(right = dedupRight(left, right))

      // When resolving `SortOrder`s in Sort based on the child, don't report errors, as
      // we still have a chance to resolve them based on the child's descendants.
      case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
        val newOrdering =
          ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
        Sort(newOrdering, global, child)

      // A special case for Generate, because the output of Generate should not be resolved by
      // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
      case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g

      case g @ Generate(generator, join, outer, qualifier, output, child) =>
        val newG = resolveExpression(generator, child, throws = true)
        if (newG.fastEquals(generator)) {
          g
        } else {
          Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
        }

      // Skips plans which contain deserializer expressions, as they should be resolved by another
      // rule: ResolveDeserializer.
      case plan if containsDeserializer(plan.expressions) => plan

      case q: LogicalPlan =>
        logTrace(s"Attempting to resolve ${q.simpleString}")
        q transformExpressionsUp {
          case u @ UnresolvedAttribute(nameParts) =>
            // Leave unchanged if resolution fails; hopefully it will be resolved in the next round.
            val result =
              withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
            logDebug(s"Resolving $u to $result")
            result
          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
            ExtractValue(child, fieldExpr, resolver)
        }
    }

    def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
      expressions.map {
        case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
        case other => other
      }
    }

    def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
      AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
    }

    /**
     * Build a project list for Project/Aggregate and expand the star if possible.
     */
    private def buildExpandedProjectList(
        exprs: Seq[NamedExpression],
        child: LogicalPlan): Seq[NamedExpression] = {
      exprs.flatMap {
        // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
        case s: Star => s.expand(child, resolver)
        // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
        case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
        case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
        case o => o :: Nil
      }.map(_.asInstanceOf[NamedExpression])
    }

    /**
     * Returns true if `exprs` contains a [[Star]].
     */
    def containsStar(exprs: Seq[Expression]): Boolean =
      exprs.exists(_.collect { case _: Star => true }.nonEmpty)

    /**
     * Expands the matching attribute.*'s in `child`'s output.
     */
    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
      expr.transformUp {
        case f1: UnresolvedFunction if containsStar(f1.children) =>
          f1.copy(children = f1.children.flatMap {
            case s: Star => s.expand(child, resolver)
            case o => o :: Nil
          })
        case c: CreateNamedStruct if containsStar(c.valExprs) =>
          val newChildren = c.children.grouped(2).flatMap {
            case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children
            case kv => kv
          }
          c.copy(children = newChildren.toList)
        case c: CreateArray if containsStar(c.children) =>
          c.copy(children = c.children.flatMap {
            case s: Star => s.expand(child, resolver)
            case o => o :: Nil
          })
        case p: Murmur3Hash if containsStar(p.children) =>
          p.copy(children = p.children.flatMap {
            case s: Star => s.expand(child, resolver)
            case o => o :: Nil
          })
        // count(*) has been replaced by count(1)
        case o if containsStar(o.children) =>
          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
      }
    }
  }

  private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
    exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
  }

  protected[sql] def resolveExpression(
      expr: Expression,
      plan: LogicalPlan,
      throws: Boolean = false) = {
    // Resolve the expression in one round.
    // If resolution fails (e.g. trying to resolve `a.b` when `a` does not exist) and `throws`
    // is false, return the original expression; otherwise propagate the exception.
    try {
      expr transformUp {
        case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
        case u @ UnresolvedAttribute(nameParts) =>
          withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
        case UnresolvedExtractValue(child, fieldName) if child.resolved =>
          ExtractValue(child, fieldName, resolver)
      }
    } catch {
      case a: AnalysisException if !throws => expr
    }
  }

  /**
   * In many dialects of SQL it is valid to use ordinal positions in ORDER BY and GROUP BY
   * clauses. This rule converts ordinal positions to the corresponding expressions in the
   * select list. This support was introduced in Spark 2.0.
   *
   * - When the sort references or group by expressions are not integers but foldable expressions,
   *   just ignore them.
   * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position
   *   numbers too.
   *
   * Before the release of Spark 2.0, the literals in order/sort by and group by clauses
   * had no effect on the results.
   */
  object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p if !p.childrenResolved => p
      // Replace the index with the related attribute for ORDER BY,
      // which is a 1-based position in the projection list.
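      // For example (illustrative): "SELECT a, b FROM t ORDER BY 2" is rewritten so that the
      // ordering refers to the second output attribute, i.e. it becomes ORDER BY b.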
      case Sort(orders, global, child)
        if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
        val newOrders = orders map {
          case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) =>
            if (index > 0 && index <= child.output.size) {
              SortOrder(child.output(index - 1), direction, nullOrdering)
            } else {
              s.failAnalysis(
                s"ORDER BY position $index is not in select list " +
                  s"(valid range is [1, ${child.output.size}])")
            }
          case o => o
        }
        Sort(newOrders, global, child)

      // Replace the index with the corresponding expression in aggregateExpressions. The index is
      // a 1-based position in aggregateExpressions, which are the output columns
      // (select expressions).
      case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
        groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
        val newGroups = groups.map {
          case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
            aggs(index - 1)
          case ordinal @ UnresolvedOrdinal(index) =>
            ordinal.failAnalysis(
              s"GROUP BY position $index is not in select list " +
                s"(valid range is [1, ${aggs.size}])")
          case o => o
        }
        Aggregate(newGroups, aggs, child)
    }
  }

  /**
   * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
   * clause. This rule detects such queries and adds the required attributes to the original
   * projection, so that they will be available during sorting. Another projection is added to
   * remove these attributes after sorting.
   *
   * The HAVING clause could also use grouping columns that are not present in the SELECT clause.
   */
  object ResolveMissingReferences extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions.
      case sa @ Sort(_, _, child: Aggregate) => sa

      case s @ Sort(order, _, child) if child.resolved =>
        try {
          val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
          val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
          val missingAttrs = requiredAttrs -- child.outputSet
          if (missingAttrs.nonEmpty) {
            // Add missing attributes and then project them away after the sort.
            Project(child.output,
              Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
          } else if (newOrder != order) {
            s.copy(order = newOrder)
          } else {
            s
          }
        } catch {
          // Attempting to resolve it might fail. When this happens, return the original plan.
          // Users will see an AnalysisException for resolution failure of missing attributes
          // in Sort.
          case ae: AnalysisException => s
        }

      case f @ Filter(cond, child) if child.resolved =>
        try {
          val newCond = resolveExpressionRecursively(cond, child)
          val requiredAttrs = newCond.references.filter(_.resolved)
          val missingAttrs = requiredAttrs -- child.outputSet
          if (missingAttrs.nonEmpty) {
            // Add missing attributes and then project them away.
            Project(child.output,
              Filter(newCond, addMissingAttr(child, missingAttrs)))
          } else if (newCond != cond) {
            f.copy(condition = newCond)
          } else {
            f
          }
        } catch {
          // Attempting to resolve it might fail. When this happens, return the original plan.
          // Users will see an AnalysisException for resolution failure of missing attributes.
          case ae: AnalysisException => f
        }
    }

    /**
     * Add the missing attributes into the projectList of Project/Window or the
     * aggregateExpressions of Aggregate.
     */
    private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
      if (missingAttrs.isEmpty) {
        return plan
      }
      plan match {
        case p: Project =>
          val missing = missingAttrs -- p.child.outputSet
          Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
        case a: Aggregate =>
          // All the missing attributes should be grouping expressions.
          // TODO: push down AggregateExpression
          missingAttrs.foreach { attr =>
            if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
              throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
            }
          }
          val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
          a.copy(aggregateExpressions = newAggregateExpressions)
        case g: Generate =>
          // If join is false, we convert it to true so that the missing attributes can be
          // obtained from the child.
          val missing = missingAttrs -- g.child.outputSet
          g.copy(join = true, child = addMissingAttr(g.child, missing))
        case d: Distinct =>
          throw new AnalysisException(s"Can't add $missingAttrs to $d")
        case u: UnaryNode =>
          u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
        case other =>
          throw new AnalysisException(s"Can't add $missingAttrs to $other")
      }
    }

    /**
     * Resolve the expression on the specified logical plan and its child (recursively), until
     * the expression is resolved or a non-unary node or a SubqueryAlias is reached.
     */
    @tailrec
    private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
      val resolved = resolveExpression(expr, plan)
      if (resolved.resolved) {
        resolved
      } else {
        plan match {
          case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] =>
            resolveExpressionRecursively(resolved, u.child)
          case other => resolved
        }
      }
    }
  }

  /**
   * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
   */
  object ResolveFunctions extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case q: LogicalPlan =>
        q transformExpressions {
          case u if !u.childrenResolved => u // Skip until children are resolved.
          case u @ UnresolvedGenerator(name, children) =>
            withPosition(u) {
              catalog.lookupFunction(name, children) match {
                case generator: Generator => generator
                case other =>
                  failAnalysis(s"$name is expected to be a generator. However, " +
                    s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
              }
            }
          case u @ UnresolvedFunction(funcId, children, isDistinct) =>
            withPosition(u) {
              catalog.lookupFunction(funcId, children) match {
                // DISTINCT is not meaningful for a Max or a Min.
                case max: Max if isDistinct =>
                  AggregateExpression(max, Complete, isDistinct = false)
                case min: Min if isDistinct =>
                  AggregateExpression(min, Complete, isDistinct = false)
                // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
                // the context of a Window clause. They do not need to be wrapped in an
                // AggregateExpression.
                case wf: AggregateWindowFunction => wf
                // If we get an aggregate function, we need to wrap it in an AggregateExpression.
                case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
                // This function is not an aggregate function, just return the resolved one.
                case other => other
              }
            }
        }
    }
  }

  /**
   * This rule resolves and rewrites subqueries inside expressions.
   *
   * Note: CTEs are handled in CTESubstitution.
   */
  object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
    /**
     * Resolve the correlated expressions in a subquery by using the outer plans' references. All
     * resolved outer references are wrapped in an [[OuterReference]].
     */
    private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
      plan transformDown {
        case q: LogicalPlan if q.childrenResolved && !q.resolved =>
          q transformExpressions {
            case u @ UnresolvedAttribute(nameParts) =>
              withPosition(u) {
                try {
                  outer.resolve(nameParts, resolver) match {
                    case Some(outerAttr) => OuterReference(outerAttr)
                    case None => u
                  }
                } catch {
                  case _: AnalysisException => u
                }
              }
          }
      }
    }

    /**
     * Pull out all (outer) correlated predicates from a given subquery. This method removes the
     * correlated predicates from subquery [[Filter]]s and adds the references of these predicates
     * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to
     * be able to evaluate the predicates at the top level.
     *
     * This method returns the rewritten subquery and the correlated predicates.
     */
    private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
      val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]

      // Make sure a plan's subtree does not contain outer references.
      def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = {
        if (p.collectFirst(predicateMap).nonEmpty) {
          failAnalysis(s"Accessing outer query column is not allowed in:\n$p")
        }
      }

      // Helper function for locating outer references.
      def containsOuter(e: Expression): Boolean = {
        e.find(_.isInstanceOf[OuterReference]).isDefined
      }

      // Make sure a plan's expressions do not contain outer references.
      def failOnOuterReference(p: LogicalPlan): Unit = {
        if (p.expressions.exists(containsOuter)) {
          failAnalysis(
            "Expressions referencing the outer query are not supported outside of WHERE/HAVING " +
              s"clauses:\n$p")
        }
      }

      // SPARK-17348: A potential incorrect result case.
      // When a correlated predicate is a non-equality predicate,
      // certain operators are not permitted from the operator
      // hosting the correlated predicate up to the operator on the outer table.
      // Otherwise, the pull-up of the correlated predicate
      // will generate a plan with different semantics,
      // which could return incorrect results.
      // Currently we check for Aggregate and Window operators.
      //
      // Below is an example of a logical plan during the Analyzer phase that
      // shows this problem. Pulling the correlated predicate [outer(c2#77) >= ..]
      // through the Aggregate (or Window) operator could alter the result of
      // the Aggregate.
      //
      // Project [c1#76]
      // +- Project [c1#87, c2#88]
      //    : (Aggregate or Window operator)
      //    : +- Filter [outer(c2#77) >= c2#88)]
      //    :    +- SubqueryAlias t2, `t2`
      //    :       +- Project [_1#84 AS c1#87, _2#85 AS c2#88]
      //    :          +- LocalRelation [_1#84, _2#85]
      // +- SubqueryAlias t1, `t1`
      //    +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
      //       +- LocalRelation [_1#73, _2#74]
      def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
        if (found) {
          // Report a non-supported case as an exception.
          failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
        }
      }

      /** Determine which correlated predicate references are missing from this plan. */
      def missingReferences(p: LogicalPlan): AttributeSet = {
        val localPredicateReferences = p.collect(predicateMap)
          .flatten
          .map(_.references)
          .reduceOption(_ ++ _)
          .getOrElse(AttributeSet.empty)
        localPredicateReferences -- p.outputSet
      }

      var foundNonEqualCorrelatedPred : Boolean = false

      // Simplify the predicates before pulling them out.
      val transformed = BooleanSimplification(sub) transformUp {

        // Whitelist operators allowed in a correlated subquery.
        // There are 4 categories:
        // 1. Operators that are allowed anywhere in a correlated subquery, and,
        //    by definition of the operators, they either do not contain
        //    any columns or cannot host outer references.
        // 2. Operators that are allowed anywhere in a correlated subquery
        //    so long as they do not host outer references.
        // 3. Operators that need special handling. These operators are
        //    Project, Filter, Join, Aggregate, and Generate.
        //
        // Any operators that are not in the above list are allowed
        // in a correlated subquery only if they are not on a correlation path.
        // In other words, these operators are allowed only under a correlation point.
        //
        // A correlation path is defined as the sub-tree of all the operators that
        // are on the path from the operator hosting the correlated expressions
        // up to the operator producing the correlated values.

        // Category 1:
        // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias
        case p: BroadcastHint =>
          p
        case p: Distinct =>
          p
        case p: LeafNode =>
          p
        case p: Repartition =>
          p
        case p: SubqueryAlias =>
          p

        // Category 2:
        // These operators can be anywhere in a correlated subquery,
        // so long as they do not host outer references in the operators.
        case p: Sort =>
          failOnOuterReference(p)
          p
        case p: RepartitionByExpression =>
          failOnOuterReference(p)
          p

        // Category 3:
        // Filter is one of the two operators allowed to host correlated expressions.
        // The other operator is Join. Filter can be anywhere in a correlated subquery.
        case f @ Filter(cond, child) =>
          // Find all predicates with an outer reference.
          val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter)

          // Find any non-equality correlated predicates.
          foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
            case _: EqualTo | _: EqualNullSafe => false
            case _ => true
          }

          // Rewrite the filter without the correlated predicates if any.
          correlated match {
            case Nil => f
            case xs if local.nonEmpty =>
              val newFilter = Filter(local.reduce(And), child)
              predicateMap += newFilter -> xs
              newFilter
            case xs =>
              predicateMap += child -> xs
              child
          }

        // Project cannot host any correlated expressions
        // but can be anywhere in a correlated subquery.
        case p @ Project(expressions, child) =>
          failOnOuterReference(p)

          val referencesToAdd = missingReferences(p)
          if (referencesToAdd.nonEmpty) {
            Project(expressions ++ referencesToAdd, child)
          } else {
            p
          }

        // Aggregate cannot host any correlated expressions.
        // It can be on a correlation path if the correlation contains
        // only equality correlated predicates.
        // It cannot be on a correlation path if the correlation has
        // non-equality correlated predicates.
        case a @ Aggregate(grouping, expressions, child) =>
          failOnOuterReference(a)
          failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)

          val referencesToAdd = missingReferences(a)
          if (referencesToAdd.nonEmpty) {
            Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
          } else {
            a
          }

        // Join can host correlated expressions.
        case j @ Join(left, right, joinType, _) =>
          joinType match {
            // Inner join, like Filter, can be anywhere.
            case _: InnerLike =>
              failOnOuterReference(j)

            // Left outer join's right operand cannot be on a correlation path.
            // LeftAnti and ExistenceJoin are special cases of LeftOuter.
            // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame,
            // so it should not show up here in the Analysis phase. This is just a safety net.
            //
            // LeftSemi does not allow output from the right operand.
            // Any correlated references in the subplan
            // of the right operand cannot be pulled up.
            case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
              failOnOuterReference(j)
              failOnOuterReferenceInSubTree(right)

            // Likewise, Right outer join's left operand cannot be on a correlation path.
            case RightOuter =>
              failOnOuterReference(j)
              failOnOuterReferenceInSubTree(left)

            // Any other join types not explicitly listed above,
            // including Full outer join, are treated as Category 4.
            case _ =>
              failOnOuterReferenceInSubTree(j)
          }
          j

        // Generator with join=true, i.e., expressed with
        // LATERAL VIEW [OUTER], similar to inner join,
        // allows correlation under it
        // but must not host any outer references.
        // Note:
        // Generator with join=false is treated as Category 4.
        case p @ Generate(generator, true, _, _, _, _) =>
          failOnOuterReference(p)
          p

        // Category 4: Any other operators not in the above 3 categories
        // cannot be on a correlation path, that is, they are allowed only
        // under a correlation point, but they and their descendant operators
        // are not allowed to have any correlated expressions.
        case p =>
          failOnOuterReferenceInSubTree(p)
          p
      }
      (transformed, predicateMap.values.flatten.toSeq)
    }

    /**
     * Rewrite the subquery in a safe way by preventing the subquery and the outer query from
     * using the same attributes.
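     *
     * For example (illustrative): if the subquery and the outer query both expose an attribute
     * `id#1`, the subquery's output is wrapped in fresh aliases so that the pulled-out join
     * conditions reference unambiguous attributes.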
     */
    private def rewriteSubQuery(
        sub: LogicalPlan,
        outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
      // Pull out the tagged predicates and rewrite the subquery in the process.
      val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub)

      // Make sure the inner and the outer query attributes do not collide.
      val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
      val duplicates = basePlan.outputSet.intersect(outputSet)
      val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
        val aliasMap = AttributeMap(duplicates.map { dup =>
          dup -> Alias(dup, dup.toString)()
        }.toSeq)
        val aliasedExpressions = basePlan.output.map { ref =>
          aliasMap.getOrElse(ref, ref)
        }
        val aliasedProjection = Project(aliasedExpressions, basePlan)
        val aliasedConditions = baseConditions.map(_.transform {
          case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
        })
        (aliasedProjection, aliasedConditions)
      } else {
        (basePlan, baseConditions)
      }
      // Remove outer references from the correlated predicates. We wait with extracting
      // these until collisions between the inner and outer query attributes have been
      // solved.
      val conditions = deDuplicatedConditions.map(_.transform {
        case OuterReference(ref) => ref
      })
      (plan, conditions)
    }

    /**
     * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method
     * will resolve the subquery by alternating between the regular analyzer and by applying the
     * resolveOuterReferences rule.
     *
     * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved.
     */
    private def resolveSubQuery(
        e: SubqueryExpression,
        plans: Seq[LogicalPlan],
        requiredColumns: Int = 0)(
        f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
      // Step 1: Resolve the outer expressions.
      var previous: LogicalPlan = null
      var current = e.plan
      do {
        // Try to resolve the subquery plan using the regular analyzer.
        previous = current
        current = execute(current)

        // Use the outer references to resolve the subquery plan if it isn't resolved yet.
        val i = plans.iterator
        val afterResolve = current
        while (!current.resolved && current.fastEquals(afterResolve) && i.hasNext) {
          current = resolveOuterReferences(current, i.next())
        }
      } while (!current.resolved && !current.fastEquals(previous))

      // Step 2: Pull out the predicates if the plan is resolved.
      if (current.resolved) {
        // Make sure the resolved query has the required number of output columns. This is only
        // needed for Scalar and IN subqueries.
        if (requiredColumns > 0 && requiredColumns != current.output.size) {
          failAnalysis(s"The number of columns in the subquery (${current.output.size}) " +
            s"does not match the required number of columns ($requiredColumns)")
        }
        // Pull out predicates and construct a new plan.
        f.tupled(rewriteSubQuery(current, plans))
      } else {
        e.withNewPlan(current)
      }
    }

    /**
     * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS
     * expressions into PredicateSubquery expressions once they are resolved.
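     *
     * For example (illustrative): `a IN (SELECT b FROM t)` becomes a [[PredicateSubquery]] whose
     * conditions include the generated equality `a = b`, plus any correlated predicates pulled
     * out of the subquery.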
     */
    private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
      plan transformExpressions {
        case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
          resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
        case e @ Exists(sub, exprId) =>
          resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId))
        case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
          // Get the left hand side expressions.
          val expressions = e match {
            case cns : CreateNamedStruct => cns.valExprs
            case expr => Seq(expr)
          }
          resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) =>
            // Construct the IN conditions.
            val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled)
            PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId)
          }
      }
    }

    /**
     * Resolve and rewrite all subqueries in an operator tree.
     */
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      // In case of HAVING (a filter after an aggregate) we use both the aggregate and
      // its child for resolution.
      case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
        resolveSubQueries(f, Seq(a, a.child))
      // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
      case q: UnaryNode if q.childrenResolved =>
        resolveSubQueries(q, q.children)
    }
  }

  /**
   * Turns projections that contain aggregate expressions into aggregations.
   */
  object GlobalAggregates extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case Project(projectList, child) if containsAggregates(projectList) =>
        Aggregate(Nil, projectList, child)
    }

    def containsAggregates(exprs: Seq[Expression]): Boolean = {
      // Collect all Windowed Aggregate Expressions.
      val windowedAggExprs = exprs.flatMap { expr =>
        expr.collect {
          case WindowExpression(ae: AggregateExpression, _) => ae
        }
      }.toSet

      // Find the first Aggregate Expression that is not Windowed.
      exprs.exists(_.collectFirst {
        case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae
      }.isDefined)
    }
  }

  /**
   * This rule finds aggregate expressions that are not in an aggregate operator. For example,
   * those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the
   * underlying aggregate operator and then projected away after the original operator.
   */
  object ResolveAggregateFunctions extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case filter @ Filter(havingCondition,
        aggregate @ Aggregate(grouping, originalAggExprs, child))
        if aggregate.resolved =>

        // Try resolving the condition of the filter as though it is in the aggregate clause.
        try {
          val aggregatedCondition =
            Aggregate(
              grouping,
              Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
              child)
          val resolvedOperator = execute(aggregatedCondition)
          def resolvedAggregateFilter =
            resolvedOperator
              .asInstanceOf[Aggregate]
              .aggregateExpressions.head

          // If resolution was successful and we see the filter has an aggregate in it, add it to
          // the original aggregate operator.
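          // For example (illustrative): in "SELECT a FROM t GROUP BY a HAVING count(b) > 0",
          // count(b) is aliased and appended to the Aggregate, the Filter compares that alias,
          // and a final Project restores the original output (just column a).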
          if (resolvedOperator.resolved) {
            // Try to replace all aggregate expressions in the filter by an alias.
            val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
            val transformedAggregateFilter = resolvedAggregateFilter.transform {
              case ae: AggregateExpression =>
                val alias = Alias(ae, ae.toString)()
                aggregateExpressions += alias
                alias.toAttribute
              // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
              case e: Expression if grouping.exists(_.semanticEquals(e)) &&
                !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
                !aggregate.output.exists(_.semanticEquals(e)) =>
                e match {
                  case ne: NamedExpression =>
                    aggregateExpressions += ne
                    ne.toAttribute
                  case _ =>
                    val alias = Alias(e, e.toString)()
                    aggregateExpressions += alias
                    alias.toAttribute
                }
            }

            // Push the aggregate expressions into the aggregate (if any).
            if (aggregateExpressions.nonEmpty) {
              Project(aggregate.output,
                Filter(transformedAggregateFilter,
                  aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
            } else {
              filter
            }
          } else {
            filter
          }
        } catch {
          // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
          // just return the original plan.
          case ae: AnalysisException => filter
        }

      case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

        // Try resolving the ordering as though it is in the aggregate clause.
        try {
          val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
          val aliasedOrdering =
            unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true))
          val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
          val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
          val resolvedAliasedOrdering: Seq[Alias] =
            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]

          // If we pass the analysis check, then the ordering expressions should only reference
          // aggregate expressions or grouping expressions, and it's safe to push them down to
          // Aggregate.
          checkAnalysis(resolvedAggregate)

          val originalAggExprs = aggregate.aggregateExpressions.map(
            CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])

          // If the ordering expression is the same as an original aggregate expression, we don't
          // need to push down this ordering expression and can reference the original aggregate
          // expression instead.
          val needsPushDown = ArrayBuffer.empty[NamedExpression]
          val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
            case (evaluated, order) =>
              val index = originalAggExprs.indexWhere {
                case Alias(child, _) => child semanticEquals evaluated.child
                case other => other semanticEquals evaluated.child
              }

              if (index == -1) {
                needsPushDown += evaluated
                order.copy(child = evaluated.toAttribute)
              } else {
                order.copy(child = originalAggExprs(index).toAttribute)
              }
          }

          val sortOrdersMap = unresolvedSortOrders
            .map(new TreeNodeRef(_))
            .zip(evaluatedOrderings)
            .toMap
          val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))

          // Since we don't rely on sort.resolved as the stop condition for this rule,
          // we need to check this and prevent applying this rule multiple times.
          if (sortOrder == finalSortOrders) {
            sort
          } else {
            Project(aggregate.output,
              Sort(finalSortOrders, global,
                aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
          }
        } catch {
          // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
          // just return the original plan.
          case ae: AnalysisException => sort
        }
    }

    def containsAggregate(condition: Expression): Boolean = {
      condition.find(_.isInstanceOf[AggregateExpression]).isDefined
    }
  }

  /**
   * Extracts [[Generator]] from the projectList of a [[Project]] operator and creates a
   * [[Generate]] operator under [[Project]].
   *
   * This rule will throw an [[AnalysisException]] for the following cases:
   * 1. [[Generator]] is nested in expressions, e.g. `SELECT explode(list) + 1 FROM tbl`
   * 2. more than one [[Generator]] is found in projectList,
   *    e.g. `SELECT explode(list), explode(list) FROM tbl`
   * 3. [[Generator]] is found in other operators that are not [[Project]] or [[Generate]],
   *    e.g. `SELECT * FROM tbl SORT BY explode(list)`
   */
  object ExtractGenerator extends Rule[LogicalPlan] {
    private def hasGenerator(expr: Expression): Boolean = {
      expr.find(_.isInstanceOf[Generator]).isDefined
    }

    private def hasNestedGenerator(expr: NamedExpression): Boolean = expr match {
      case UnresolvedAlias(_: Generator, _) => false
      case Alias(_: Generator, _) => false
      case MultiAlias(_: Generator, _) => false
      case other => hasGenerator(other)
    }

    private def trimAlias(expr: NamedExpression): Expression = expr match {
      case UnresolvedAlias(child, _) => child
      case Alias(child, _) => child
      case MultiAlias(child, _) => child
      case _ => expr
    }

    /**
     * Extracts a [[Generator]] expression and any names assigned by aliases to their output.
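     *
     * For example (illustrative), `explode(m) AS (k, v)` matches as the generator `explode(m)`
     * paired with the output names `Seq("k", "v")`.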
*/ 1469 private object AliasedGenerator { 1470 def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { 1471 case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) 1472 case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) 1473 case _ => None 1474 } 1475 } 1476 1477 def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 1478 case Project(projectList, _) if projectList.exists(hasNestedGenerator) => 1479 val nestedGenerator = projectList.find(hasNestedGenerator).get 1480 throw new AnalysisException("Generators are not supported when it's nested in " + 1481 "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) 1482 1483 case Project(projectList, _) if projectList.count(hasGenerator) > 1 => 1484 val generators = projectList.filter(hasGenerator).map(trimAlias) 1485 throw new AnalysisException("Only one generator allowed per select clause but found " + 1486 generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) 1487 1488 case p @ Project(projectList, child) => 1489 // Holds the resolved generator, if one exists in the project list. 1490 var resolvedGenerator: Generate = null 1491 1492 val newProjectList = projectList.flatMap { 1493 case AliasedGenerator(generator, names) if generator.childrenResolved => 1494 // It's a sanity check, this should not happen as the previous case will throw 1495 // exception earlier. 1496 assert(resolvedGenerator == null, "More than one generator found in SELECT.") 1497 1498 resolvedGenerator = 1499 Generate( 1500 generator, 1501 join = projectList.size > 1, // Only join if there are other expressions in SELECT. 1502 outer = false, 1503 qualifier = None, 1504 generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), 1505 child) 1506 1507 resolvedGenerator.generatorOutput 1508 case other => other :: Nil 1509 } 1510 1511 if (resolvedGenerator != null) { 1512 Project(newProjectList, resolvedGenerator) 1513 } else { 1514 p 1515 } 1516 1517 case g: Generate => g 1518 1519 case p if p.expressions.exists(hasGenerator) => 1520 throw new AnalysisException("Generators are not supported outside the SELECT clause, but " + 1521 "got: " + p.simpleString) 1522 } 1523 } 1524 1525 /** 1526 * Rewrites table generating expressions that either need one or more of the following in order 1527 * to be resolved: 1528 * - concrete attribute references for their output. 1529 * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). 1530 * 1531 * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions 1532 * that wrap the [[Generator]]. 1533 */ 1534 object ResolveGenerate extends Rule[LogicalPlan] { 1535 def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 1536 case g: Generate if !g.child.resolved || !g.generator.resolved => g 1537 case g: Generate if !g.resolved => 1538 g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) 1539 } 1540 1541 /** 1542 * Construct the output attributes for a [[Generator]], given a list of names. If the list of 1543 * names is empty names are assigned from field names in generator. 
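* For example (illustrative): `explode(map_col)` produces the attributes `key` and `value`;
* `explode(map_col) AS (k, v)` renames them, and supplying a single alias for that two-column
* output fails the analysis with the mismatch error below.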
1544 */ 1545 private[analysis] def makeGeneratorOutput( 1546 generator: Generator, 1547 names: Seq[String]): Seq[Attribute] = { 1548 val elementAttrs = generator.elementSchema.toAttributes 1549 1550 if (names.length == elementAttrs.length) { 1551 names.zip(elementAttrs).map { 1552 case (name, attr) => attr.withName(name) 1553 } 1554 } else if (names.isEmpty) { 1555 elementAttrs 1556 } else { 1557 failAnalysis( 1558 "The number of aliases supplied in the AS clause does not match the number of columns " + 1559 s"output by the UDTF expected ${elementAttrs.size} aliases but got " + 1560 s"${names.mkString(",")} ") 1561 } 1562 } 1563 } 1564 1565 /** 1566 * Fixes nullability of Attributes in a resolved LogicalPlan by using the nullability of 1567 * corresponding Attributes of its children output Attributes. This step is needed because 1568 * users can use a resolved AttributeReference in the Dataset API and outer joins 1569 * can change the nullability of an AttribtueReference. Without the fix, a nullable column's 1570 * nullable field can be actually set as non-nullable, which cause illegal optimization 1571 * (e.g., NULL propagation) and wrong answers. 1572 * See SPARK-13484 and SPARK-13801 for the concrete queries of this case. 1573 */ 1574 object FixNullability extends Rule[LogicalPlan] { 1575 1576 def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { 1577 case p if !p.resolved => p // Skip unresolved nodes. 1578 case p: LogicalPlan if p.resolved => 1579 val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { 1580 case (exprId, attributes) => 1581 // If there are multiple Attributes having the same ExprId, we need to resolve 1582 // the conflict of nullable field. We do not really expect this happen. 1583 val nullable = attributes.exists(_.nullable) 1584 attributes.map(attr => attr.withNullability(nullable)) 1585 }.toSeq 1586 // At here, we create an AttributeMap that only compare the exprId for the lookup 1587 // operation. So, we can find the corresponding input attribute's nullability. 1588 val attributeMap = AttributeMap[Attribute](childrenOutput.map(attr => attr -> attr)) 1589 // For an Attribute used by the current LogicalPlan, if it is from its children, 1590 // we fix the nullable field by using the nullability setting of the corresponding 1591 // output Attribute from the children. 1592 p.transformExpressions { 1593 case attr: Attribute if attributeMap.contains(attr) => 1594 attr.withNullability(attributeMap(attr).nullable) 1595 } 1596 } 1597 } 1598 1599 /** 1600 * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and 1601 * aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]] 1602 * operators for every distinct [[WindowSpecDefinition]]. 1603 * 1604 * This rule handles three cases: 1605 * - A [[Project]] having [[WindowExpression]]s in its projectList; 1606 * - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions. 1607 * - A [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING 1608 * clause and the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions. 1609 * Note: If there is a GROUP BY clause in the query, aggregations and corresponding 1610 * filters (expressions in the HAVING clause) should be evaluated before any 1611 * [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be 1612 * evaluated after all [[WindowExpression]]s. 1613 * 1614 * For every case, the transformation works as follows: 1615 * 1. 
For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions 1616 * it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for 1617 * all regular expressions. 1618 * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s. 1619 * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts 1620 * it into the plan tree. 1621 */ 1622 object ExtractWindowExpressions extends Rule[LogicalPlan] { 1623 private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = 1624 projectList.exists(hasWindowFunction) 1625 1626 private def hasWindowFunction(expr: NamedExpression): Boolean = { 1627 expr.find { 1628 case window: WindowExpression => true 1629 case _ => false 1630 }.isDefined 1631 } 1632 1633 /** 1634 * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and 1635 * other regular expressions that do not contain any window expression. For example, for 1636 * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract 1637 * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in 1638 * the window expression as attribute references. So, the first returned value will be 1639 * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be 1640 * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]. 1641 * 1642 * @return (seq of expressions containing at lease one window expressions, 1643 * seq of non-window expressions) 1644 */ 1645 private def extract( 1646 expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = { 1647 // First, we partition the input expressions to two part. For the first part, 1648 // every expression in it contain at least one WindowExpression. 1649 // Expressions in the second part do not have any WindowExpression. 1650 val (expressionsWithWindowFunctions, regularExpressions) = 1651 expressions.partition(hasWindowFunction) 1652 1653 // Then, we need to extract those regular expressions used in the WindowExpression. 1654 // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5), 1655 // we need to make sure that col1 to col5 are all projected from the child of the Window 1656 // operator. 1657 val extractedExprBuffer = new ArrayBuffer[NamedExpression]() 1658 def extractExpr(expr: Expression): Expression = expr match { 1659 case ne: NamedExpression => 1660 // If a named expression is not in regularExpressions, add it to 1661 // extractedExprBuffer and replace it with an AttributeReference. 1662 val missingExpr = 1663 AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer) 1664 if (missingExpr.nonEmpty) { 1665 extractedExprBuffer += ne 1666 } 1667 // alias will be cleaned in the rule CleanupAliases 1668 ne 1669 case e: Expression if e.foldable => 1670 e // No need to create an attribute reference if it will be evaluated as a Literal. 1671 case e: Expression => 1672 // For other expressions, we extract it and replace it with an AttributeReference (with 1673 // an internal column name, e.g. "_w0"). 1674 val withName = Alias(e, s"_w${extractedExprBuffer.length}")() 1675 extractedExprBuffer += withName 1676 withName.toAttribute 1677 } 1678 1679 // Now, we extract regular expressions from expressionsWithWindowFunctions 1680 // by using extractExpr. 
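// For example (illustrative): in `ntile(4) OVER (ORDER BY a)`, the foldable literal `4` is left
// in place, while the ordering attribute `a` is only added to extractedExprBuffer if it is not
// already among the regular expressions.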
1681 val seenWindowAggregates = new ArrayBuffer[AggregateExpression] 1682 val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { 1683 _.transform { 1684 // Extracts children expressions of a WindowFunction (input parameters of 1685 // a WindowFunction). 1686 case wf: WindowFunction => 1687 val newChildren = wf.children.map(extractExpr) 1688 wf.withNewChildren(newChildren) 1689 1690 // Extracts expressions from the partition spec and order spec. 1691 case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) => 1692 val newPartitionSpec = partitionSpec.map(extractExpr) 1693 val newOrderSpec = orderSpec.map { so => 1694 val newChild = extractExpr(so.child) 1695 so.copy(child = newChild) 1696 } 1697 wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) 1698 1699 // Extract Windowed AggregateExpression 1700 case we @ WindowExpression( 1701 ae @ AggregateExpression(function, _, _, _), 1702 spec: WindowSpecDefinition) => 1703 val newChildren = function.children.map(extractExpr) 1704 val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] 1705 val newAgg = ae.copy(aggregateFunction = newFunction) 1706 seenWindowAggregates += newAgg 1707 WindowExpression(newAgg, spec) 1708 1709 // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), 1710 // we need to extract SUM(x). 1711 case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => 1712 val withName = Alias(agg, s"_w${extractedExprBuffer.length}")() 1713 extractedExprBuffer += withName 1714 withName.toAttribute 1715 1716 // Extracts other attributes 1717 case attr: Attribute => extractExpr(attr) 1718 1719 }.asInstanceOf[NamedExpression] 1720 } 1721 1722 (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer) 1723 } // end of extract 1724 1725 /** 1726 * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. 1727 */ 1728 private def addWindow( 1729 expressionsWithWindowFunctions: Seq[NamedExpression], 1730 child: LogicalPlan): LogicalPlan = { 1731 // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions 1732 // and put those extracted WindowExpressions to extractedWindowExprBuffer. 1733 // This step is needed because it is possible that an expression contains multiple 1734 // WindowExpressions with different Window Specs. 1735 // After extracting WindowExpressions, we need to construct a project list to generate 1736 // expressionsWithWindowFunctions based on extractedWindowExprBuffer. 1737 // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract 1738 // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to 1739 // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)". 1740 // Then, the projectList will be [_we0/_we1]. 1741 val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]() 1742 val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { 1743 // We need to use transformDown because we want to trigger 1744 // "case alias @ Alias(window: WindowExpression, _)" first. 1745 _.transformDown { 1746 case alias @ Alias(window: WindowExpression, _) => 1747 // If a WindowExpression has an assigned alias, just use it. 1748 extractedWindowExprBuffer += alias 1749 alias.toAttribute 1750 case window: WindowExpression => 1751 // If there is no alias assigned to the WindowExpressions. We create an 1752 // internal column. 
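// For example (illustrative): in `SELECT sum(a) OVER w AS s, sum(b) OVER w FROM ...`,
// the first window expression keeps its alias `s`, while the unaliased second one is
// given a generated `_we<n>` name here.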
1753 val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")() 1754 extractedWindowExprBuffer += withName 1755 withName.toAttribute 1756 }.asInstanceOf[NamedExpression] 1757 } 1758 1759 // Second, we group extractedWindowExprBuffer based on their Partition and Order Specs. 1760 val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => 1761 val distinctWindowSpec = expr.collect { 1762 case window: WindowExpression => window.windowSpec 1763 }.distinct 1764 1765 // We do a final check and see if we only have a single Window Spec defined in an 1766 // expressions. 1767 if (distinctWindowSpec.isEmpty) { 1768 failAnalysis(s"$expr does not have any WindowExpression.") 1769 } else if (distinctWindowSpec.length > 1) { 1770 // newExpressionsWithWindowFunctions only have expressions with a single 1771 // WindowExpression. If we reach here, we have a bug. 1772 failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + 1773 s"Please file a bug report with this error message, stack trace, and the query.") 1774 } else { 1775 val spec = distinctWindowSpec.head 1776 (spec.partitionSpec, spec.orderSpec) 1777 } 1778 }.toSeq 1779 1780 // Third, we aggregate them by adding each Window operator for each Window Spec and then 1781 // setting this to the child of the next Window operator. 1782 val windowOps = 1783 groupedWindowExpressions.foldLeft(child) { 1784 case (last, ((partitionSpec, orderSpec), windowExpressions)) => 1785 Window(windowExpressions, partitionSpec, orderSpec, last) 1786 } 1787 1788 // Finally, we create a Project to output windowOps's output 1789 // newExpressionsWithWindowFunctions. 1790 Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps) 1791 } // end of addWindow 1792 1793 // We have to use transformDown at here to make sure the rule of 1794 // "Aggregate with Having clause" will be triggered. 1795 def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { 1796 1797 // Aggregate with Having clause. This rule works with an unresolved Aggregate because 1798 // a resolved Aggregate will not have Window Functions. 1799 case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) 1800 if child.resolved && 1801 hasWindowFunction(aggregateExprs) && 1802 a.expressions.forall(_.resolved) => 1803 val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) 1804 // Create an Aggregate operator to evaluate aggregation functions. 1805 val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) 1806 // Add a Filter operator for conditions in the Having clause. 1807 val withFilter = Filter(condition, withAggregate) 1808 val withWindow = addWindow(windowExpressions, withFilter) 1809 1810 // Finally, generate output columns according to the original projectList. 1811 val finalProjectList = aggregateExprs.map(_.toAttribute) 1812 Project(finalProjectList, withWindow) 1813 1814 case p: LogicalPlan if !p.childrenResolved => p 1815 1816 // Aggregate without Having clause. 1817 case a @ Aggregate(groupingExprs, aggregateExprs, child) 1818 if hasWindowFunction(aggregateExprs) && 1819 a.expressions.forall(_.resolved) => 1820 val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) 1821 // Create an Aggregate operator to evaluate aggregation functions. 1822 val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) 1823 // Add Window operators. 
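// For example (illustrative): `SELECT a, sum(b), rank() OVER (ORDER BY a) FROM t GROUP BY a`
// first evaluates the Aggregate for `a` and `sum(b)`, then the Window for `rank()`, and the
// final Project below restores the original output columns.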
1824 val withWindow = addWindow(windowExpressions, withAggregate) 1825 1826 // Finally, generate output columns according to the original projectList. 1827 val finalProjectList = aggregateExprs.map(_.toAttribute) 1828 Project(finalProjectList, withWindow) 1829 1830 // We only extract Window Expressions after all expressions of the Project 1831 // have been resolved. 1832 case p @ Project(projectList, child) 1833 if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) => 1834 val (windowExpressions, regularExpressions) = extract(projectList) 1835 // We add a project to get all needed expressions for window expressions from the child 1836 // of the original Project operator. 1837 val withProject = Project(regularExpressions, child) 1838 // Add Window operators. 1839 val withWindow = addWindow(windowExpressions, withProject) 1840 1841 // Finally, generate output columns according to the original projectList. 1842 val finalProjectList = projectList.map(_.toAttribute) 1843 Project(finalProjectList, withWindow) 1844 } 1845 } 1846 1847 /** 1848 * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter, 1849 * put them into an inner Project and finally project them away at the outer Project. 1850 */ 1851 object PullOutNondeterministic extends Rule[LogicalPlan] { 1852 override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 1853 case p if !p.resolved => p // Skip unresolved nodes. 1854 case p: Project => p 1855 case f: Filter => f 1856 1857 case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) => 1858 val nondeterToAttr = getNondeterToAttr(a.groupingExpressions) 1859 val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child) 1860 a.transformExpressions { case e => 1861 nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) 1862 }.copy(child = newChild) 1863 1864 // todo: It's hard to write a general rule to pull out nondeterministic expressions 1865 // from LogicalPlan, currently we only do it for UnaryNode which has same output 1866 // schema with its child. 1867 case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => 1868 val nondeterToAttr = getNondeterToAttr(p.expressions) 1869 val newPlan = p.transformExpressions { case e => 1870 nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) 1871 } 1872 val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child) 1873 Project(p.output, newPlan.withNewChildren(newChild :: Nil)) 1874 } 1875 1876 private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = { 1877 exprs.filterNot(_.deterministic).flatMap { expr => 1878 val leafNondeterministic = expr.collect { case n: Nondeterministic => n } 1879 leafNondeterministic.distinct.map { e => 1880 val ne = e match { 1881 case n: NamedExpression => n 1882 case _ => Alias(e, "_nondeterministic")(isGenerated = true) 1883 } 1884 e -> ne 1885 } 1886 }.toMap 1887 } 1888 } 1889 1890 /** 1891 * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the 1892 * null check. When user defines a UDF with primitive parameters, there is no way to tell if the 1893 * primitive parameter is null or not, so here we assume the primitive input is null-propagatable 1894 * and we should return null if the input is null. 1895 */ 1896 object HandleNullInputsForUDF extends Rule[LogicalPlan] { 1897 override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 1898 case p if !p.resolved => p // Skip unresolved nodes. 
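// For example (illustrative): a Scala UDF `(i: Int) => i + 1` applied to a nullable column `a`
// is rewritten below to `If(IsNull(a), Literal(null), udf(a))`, because a primitive Int
// parameter has no way to represent a null input.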
1899 1900 case p => p transformExpressionsUp { 1901 1902 case udf @ ScalaUDF(func, _, inputs, _, _) => 1903 val parameterTypes = ScalaReflection.getParameterTypes(func) 1904 assert(parameterTypes.length == inputs.length) 1905 1906 val inputsNullCheck = parameterTypes.zip(inputs) 1907 // TODO: skip null handling for not-nullable primitive inputs after we can completely 1908 // trust the `nullable` information. 1909 // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } 1910 .filter { case (cls, _) => cls.isPrimitive } 1911 .map { case (_, expr) => IsNull(expr) } 1912 .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) 1913 inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) 1914 } 1915 } 1916 } 1917 1918 /** 1919 * Check and add proper window frames for all window functions. 1920 */ 1921 object ResolveWindowFrame extends Rule[LogicalPlan] { 1922 def apply(plan: LogicalPlan): LogicalPlan = plan transform { 1923 case logical: LogicalPlan => logical transformExpressions { 1924 case WindowExpression(wf: WindowFunction, 1925 WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) 1926 if wf.frame != UnspecifiedFrame && wf.frame != f => 1927 failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") 1928 case WindowExpression(wf: WindowFunction, 1929 s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) 1930 if wf.frame != UnspecifiedFrame => 1931 WindowExpression(wf, s.copy(frameSpecification = wf.frame)) 1932 case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) 1933 if e.resolved => 1934 val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) 1935 we.copy(windowSpec = s.copy(frameSpecification = frame)) 1936 } 1937 } 1938 } 1939 1940 /** 1941 * Check and add order to [[AggregateWindowFunction]]s. 1942 */ 1943 object ResolveWindowOrder extends Rule[LogicalPlan] { 1944 def apply(plan: LogicalPlan): LogicalPlan = plan transform { 1945 case logical: LogicalPlan => logical transformExpressions { 1946 case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => 1947 failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + 1948 s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " + 1949 s"ORDER BY window_ordering) from table") 1950 case WindowExpression(rank: RankLike, spec) if spec.resolved => 1951 val order = spec.orderSpec.map(_.child) 1952 WindowExpression(rank.withOrder(order), spec) 1953 } 1954 } 1955 } 1956 1957 /** 1958 * Removes natural or using joins by calculating output columns based on output from two sides, 1959 * Then apply a Project on a normal Join to eliminate natural or using join. 
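*
* For example (illustrative): given t1(a, b) and t2(a, c), `SELECT * FROM t1 JOIN t2 USING (a)`
* is rewritten to roughly `Project(a, b, c, Join(t1, t2, Inner, Some(t1.a = t2.a)))`, so later
* phases only ever see an ordinary join with an explicit equality condition.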
1960 */ 1961 object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { 1962 override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 1963 case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) 1964 if left.resolved && right.resolved && j.duplicateResolved => 1965 commonNaturalJoinProcessing(left, right, joinType, usingCols, None) 1966 case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => 1967 // find common column names from both sides 1968 val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) 1969 commonNaturalJoinProcessing(left, right, joinType, joinNames, condition) 1970 } 1971 } 1972 1973 private def commonNaturalJoinProcessing( 1974 left: LogicalPlan, 1975 right: LogicalPlan, 1976 joinType: JoinType, 1977 joinNames: Seq[String], 1978 condition: Option[Expression]) = { 1979 val leftKeys = joinNames.map { keyName => 1980 left.output.find(attr => resolver(attr.name, keyName)).getOrElse { 1981 throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " + 1982 s"side of the join. The left-side columns: [${left.output.map(_.name).mkString(", ")}]") 1983 } 1984 } 1985 val rightKeys = joinNames.map { keyName => 1986 right.output.find(attr => resolver(attr.name, keyName)).getOrElse { 1987 throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " + 1988 s"side of the join. The right-side columns: [${right.output.map(_.name).mkString(", ")}]") 1989 } 1990 } 1991 val joinPairs = leftKeys.zip(rightKeys) 1992 1993 val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And) 1994 1995 // columns not in joinPairs 1996 val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) 1997 val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) 1998 1999 // the output list looks like: join keys, columns from left, columns from right 2000 val projectList = joinType match { 2001 case LeftOuter => 2002 leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) 2003 case LeftExistence(_) => 2004 leftKeys ++ lUniqueOutput 2005 case RightOuter => 2006 rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput 2007 case FullOuter => 2008 // in full outer join, joinCols should be non-null if there is. 2009 val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() } 2010 joinedCols ++ 2011 lUniqueOutput.map(_.withNullability(true)) ++ 2012 rUniqueOutput.map(_.withNullability(true)) 2013 case _ : InnerLike => 2014 leftKeys ++ lUniqueOutput ++ rUniqueOutput 2015 case _ => 2016 sys.error("Unsupported natural join type " + joinType) 2017 } 2018 // use Project to trim unnecessary fields 2019 Project(projectList, Join(left, right, joinType, newCondition)) 2020 } 2021 2022 /** 2023 * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved 2024 * to the given input attributes. 
2025 */ 2026 object ResolveDeserializer extends Rule[LogicalPlan] { 2027 def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 2028 case p if !p.childrenResolved => p 2029 case p if p.resolved => p 2030 2031 case p => p transformExpressions { 2032 case UnresolvedDeserializer(deserializer, inputAttributes) => 2033 val inputs = if (inputAttributes.isEmpty) { 2034 p.children.flatMap(_.output) 2035 } else { 2036 inputAttributes 2037 } 2038 2039 validateTopLevelTupleFields(deserializer, inputs) 2040 val resolved = resolveExpression( 2041 deserializer, LocalRelation(inputs), throws = true) 2042 validateNestedTupleFields(resolved) 2043 resolved 2044 } 2045 } 2046 2047 private def fail(schema: StructType, maxOrdinal: Int): Unit = { 2048 throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " + 2049 "but failed as the number of fields does not line up.") 2050 } 2051 2052 /** 2053 * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column 2054 * by position. However, the actual number of columns may be different from the number of Tuple 2055 * fields. This method is used to check the number of columns and fields, and throw an 2056 * exception if they do not match. 2057 */ 2058 private def validateTopLevelTupleFields( 2059 deserializer: Expression, inputs: Seq[Attribute]): Unit = { 2060 val ordinals = deserializer.collect { 2061 case GetColumnByOrdinal(ordinal, _) => ordinal 2062 }.distinct.sorted 2063 2064 if (ordinals.nonEmpty && ordinals != inputs.indices) { 2065 fail(inputs.toStructType, ordinals.last) 2066 } 2067 } 2068 2069 /** 2070 * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field 2071 * by position. However, the actual number of struct fields may be different from the number 2072 * of nested Tuple fields. This method is used to check the number of struct fields and nested 2073 * Tuple fields, and throw an exception if they do not match. 2074 */ 2075 private def validateNestedTupleFields(deserializer: Expression): Unit = { 2076 val structChildToOrdinals = deserializer 2077 // There are 2 kinds of `GetStructField`: 2078 // 1. resolved from `UnresolvedExtractValue`, and it will have a `name` property. 2079 // 2. created when we build deserializer expression for nested tuple, no `name` property. 2080 // Here we want to validate the ordinals of nested tuple, so we should only catch 2081 // `GetStructField` without the name property. 2082 .collect { case g: GetStructField if g.name.isEmpty => g } 2083 .groupBy(_.child) 2084 .mapValues(_.map(_.ordinal).distinct.sorted) 2085 2086 structChildToOrdinals.foreach { case (expr, ordinals) => 2087 val schema = expr.dataType.asInstanceOf[StructType] 2088 if (ordinals != schema.indices) { 2089 fail(schema, ordinals.last) 2090 } 2091 } 2092 } 2093 } 2094 2095 /** 2096 * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being 2097 * constructed is an inner class. 
2098 */ 2099 object ResolveNewInstance extends Rule[LogicalPlan] { 2100 def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 2101 case p if !p.childrenResolved => p 2102 case p if p.resolved => p 2103 2104 case p => p transformExpressions { 2105 case n: NewInstance if n.childrenResolved && !n.resolved => 2106 val outer = OuterScopes.getOuterScope(n.cls) 2107 if (outer == null) { 2108 throw new AnalysisException( 2109 s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + 2110 "access to the scope that this class was defined in.\n" + 2111 "Try moving this class out of its parent class.") 2112 } 2113 n.copy(outerPointer = Some(outer)) 2114 } 2115 } 2116 } 2117 2118 /** 2119 * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate. 2120 */ 2121 object ResolveUpCast extends Rule[LogicalPlan] { 2122 private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { 2123 throw new AnalysisException(s"Cannot up cast ${from.sql} from " + 2124 s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + 2125 "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + 2126 "You can either add an explicit cast to the input data or choose a higher precision " + 2127 "type of the field in the target object") 2128 } 2129 2130 private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { 2131 val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) 2132 val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) 2133 toPrecedence > 0 && fromPrecedence > toPrecedence 2134 } 2135 2136 def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 2137 case p if !p.childrenResolved => p 2138 case p if p.resolved => p 2139 2140 case p => p transformExpressions { 2141 case u @ UpCast(child, _, _) if !child.resolved => u 2142 2143 case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { 2144 case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => 2145 fail(child, to, walkedTypePath) 2146 case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => 2147 fail(child, to, walkedTypePath) 2148 case (from, to) if illegalNumericPrecedence(from, to) => 2149 fail(child, to, walkedTypePath) 2150 case (TimestampType, DateType) => 2151 fail(child, DateType, walkedTypePath) 2152 case (StringType, to: NumericType) => 2153 fail(child, to, walkedTypePath) 2154 case _ => Cast(child, dataType.asNullable) 2155 } 2156 } 2157 } 2158 } 2159} 2160 2161/** 2162 * Removes [[SubqueryAlias]] operators from the plan. Subqueries are only required to provide 2163 * scoping information for attributes and can be removed once analysis is complete. 2164 */ 2165object EliminateSubqueryAliases extends Rule[LogicalPlan] { 2166 def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { 2167 case SubqueryAlias(_, child, _) => child 2168 } 2169} 2170 2171/** 2172 * Removes [[Union]] operators from the plan if it just has one child. 2173 */ 2174object EliminateUnions extends Rule[LogicalPlan] { 2175 def apply(plan: LogicalPlan): LogicalPlan = plan transform { 2176 case Union(children) if children.size == 1 => children.head 2177 } 2178} 2179 2180/** 2181 * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level 2182 * expression in Project(project list) or Aggregate(aggregate expressions) or 2183 * Window(window expressions). 
2184 */ 2185object CleanupAliases extends Rule[LogicalPlan] { 2186 private def trimAliases(e: Expression): Expression = { 2187 e.transformDown { 2188 case Alias(child, _) => child 2189 } 2190 } 2191 2192 def trimNonTopLevelAliases(e: Expression): Expression = e match { 2193 case a: Alias => 2194 a.withNewChildren(trimAliases(a.child) :: Nil) 2195 case other => trimAliases(other) 2196 } 2197 2198 override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 2199 case Project(projectList, child) => 2200 val cleanedProjectList = 2201 projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) 2202 Project(cleanedProjectList, child) 2203 2204 case Aggregate(grouping, aggs, child) => 2205 val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) 2206 Aggregate(grouping.map(trimAliases), cleanedAggs, child) 2207 2208 case w @ Window(windowExprs, partitionSpec, orderSpec, child) => 2209 val cleanedWindowExprs = 2210 windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) 2211 Window(cleanedWindowExprs, partitionSpec.map(trimAliases), 2212 orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) 2213 2214 // Operators that operate on objects should only have expressions from encoders, which should 2215 // never have extra aliases. 2216 case o: ObjectConsumer => o 2217 case o: ObjectProducer => o 2218 case a: AppendColumns => a 2219 2220 case other => 2221 other transformExpressionsDown { 2222 case Alias(child, _) => child 2223 } 2224 } 2225} 2226 2227/** 2228 * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to 2229 * figure out how many windows a time column can map to, we over-estimate the number of windows and 2230 * filter out the rows where the time column is not inside the time window. 2231 */ 2232object TimeWindowing extends Rule[LogicalPlan] { 2233 import org.apache.spark.sql.catalyst.dsl.expressions._ 2234 2235 private final val WINDOW_START = "start" 2236 private final val WINDOW_END = "end" 2237 2238 /** 2239 * Generates the logical plan for generating window ranges on a timestamp column. Without 2240 * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many 2241 * window ranges a timestamp will map to given all possible combinations of a window duration, 2242 * slide duration and start time (offset). Therefore, we express and over-estimate the number of 2243 * windows there may be, and filter the valid windows. We use last Project operator to group 2244 * the window columns into a struct so they can be accessed as `window.start` and `window.end`. 2245 * 2246 * The windows are calculated as below: 2247 * maxNumOverlapping <- ceil(windowDuration / slideDuration) 2248 * for (i <- 0 until maxNumOverlapping) 2249 * windowId <- ceil((timestamp - startTime) / slideDuration) 2250 * windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime 2251 * windowEnd <- windowStart + windowDuration 2252 * return windowStart, windowEnd 2253 * 2254 * This behaves as follows for the given parameters for the time: 12:05. The valid windows are 2255 * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the 2256 * Filter operator. 
2257 * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m 2258 * 11:55 - 12:07 + 11:52 - 12:04 x 2259 * 12:00 - 12:12 + 11:57 - 12:09 + 2260 * 12:05 - 12:17 + 12:02 - 12:14 + 2261 * 2262 * @param plan The logical plan 2263 * @return the logical plan that will generate the time windows using the Expand operator, with 2264 * the Filter operator for correctness and Project for usability. 2265 */ 2266 def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { 2267 case p: LogicalPlan if p.children.size == 1 => 2268 val child = p.children.head 2269 val windowExpressions = 2270 p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct. 2271 2272 // Only support a single window expression for now 2273 if (windowExpressions.size == 1 && 2274 windowExpressions.head.timeColumn.resolved && 2275 windowExpressions.head.checkInputDataTypes().isSuccess) { 2276 val window = windowExpressions.head 2277 2278 val metadata = window.timeColumn match { 2279 case a: Attribute => a.metadata 2280 case _ => Metadata.empty 2281 } 2282 val windowAttr = 2283 AttributeReference("window", window.dataType, metadata = metadata)() 2284 2285 val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt 2286 val windows = Seq.tabulate(maxNumOverlapping + 1) { i => 2287 val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / 2288 window.slideDuration) 2289 val windowStart = (windowId + i - maxNumOverlapping) * 2290 window.slideDuration + window.startTime 2291 val windowEnd = windowStart + window.windowDuration 2292 2293 CreateNamedStruct( 2294 Literal(WINDOW_START) :: windowStart :: 2295 Literal(WINDOW_END) :: windowEnd :: Nil) 2296 } 2297 2298 val projections = windows.map(_ +: p.children.head.output) 2299 2300 val filterExpr = 2301 window.timeColumn >= windowAttr.getField(WINDOW_START) && 2302 window.timeColumn < windowAttr.getField(WINDOW_END) 2303 2304 val expandedPlan = 2305 Filter(filterExpr, 2306 Expand(projections, windowAttr +: child.output, child)) 2307 2308 val substitutedPlan = p transformExpressions { 2309 case t: TimeWindow => windowAttr 2310 } 2311 2312 substitutedPlan.withNewChildren(expandedPlan :: Nil) 2313 } else if (windowExpressions.size > 1) { 2314 p.failAnalysis("Multiple time window expressions would result in a cartesian product " + 2315 "of rows, therefore they are not currently not supported.") 2316 } else { 2317 p // Return unchanged. Analyzer will throw exception later 2318 } 2319 } 2320} 2321 2322/** 2323 * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 2324 */ 2325object ResolveCreateNamedStruct extends Rule[LogicalPlan] { 2326 override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { 2327 case e: CreateNamedStruct if !e.resolved => 2328 val children = e.children.grouped(2).flatMap { 2329 case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => 2330 Seq(Literal(e.name), e) 2331 case kv => 2332 kv 2333 } 2334 CreateNamedStruct(children.toList) 2335 } 2336} 2337
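
// An illustrative, self-contained sketch of the window-assignment arithmetic documented in
// [[TimeWindowing]] above: over-estimate the candidate windows, then keep only the ranges that
// actually contain the timestamp. The object and method names are hypothetical and nothing in
// the analyzer refers to them.
private[analysis] object TimeWindowingArithmeticSketch {
  /**
   * Returns the (start, end) ranges, in the same units as the inputs, that `timestamp` falls into.
   * e.g. windowsFor(725, 12, 5, 0) == Seq((715, 727), (720, 732), (725, 737)), matching the three
   * '+' windows (in minutes) in the 12:05 example of the TimeWindowing Scaladoc.
   */
  def windowsFor(timestamp: Long, window: Long, slide: Long, start: Long): Seq[(Long, Long)] = {
    val maxNumOverlapping = math.ceil(window.toDouble / slide).toInt
    (0 to maxNumOverlapping).flatMap { i =>
      val windowId = math.ceil((timestamp - start).toDouble / slide).toLong
      val windowStart = (windowId + i - maxNumOverlapping) * slide + start
      val windowEnd = windowStart + window
      // Mirrors the Filter added by the rule: keep only candidates that contain the timestamp.
      if (timestamp >= windowStart && timestamp < windowEnd) Some((windowStart, windowEnd)) else None
    }
  }
}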