1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.analysis
19
20import scala.annotation.tailrec
21import scala.collection.mutable.ArrayBuffer
22
23import org.apache.spark.sql.AnalysisException
24import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
25import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
26import org.apache.spark.sql.catalyst.encoders.OuterScopes
27import org.apache.spark.sql.catalyst.expressions._
28import org.apache.spark.sql.catalyst.expressions.aggregate._
29import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
30import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
31import org.apache.spark.sql.catalyst.plans._
32import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
33import org.apache.spark.sql.catalyst.rules._
34import org.apache.spark.sql.catalyst.trees.TreeNodeRef
35import org.apache.spark.sql.catalyst.util.toPrettySQL
36import org.apache.spark.sql.types._
37
38/**
39 * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
40 * Used for testing when all relations are already filled in and the analyzer needs only
41 * to resolve attribute references.
42 */
43object SimpleAnalyzer extends Analyzer(
44    new SessionCatalog(
45      new InMemoryCatalog,
46      EmptyFunctionRegistry,
47      new SimpleCatalystConf(caseSensitiveAnalysis = true)) {
48      override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {}
49    },
50    new SimpleCatalystConf(caseSensitiveAnalysis = true))
51
52/**
53 * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
54 * [[UnresolvedRelation]]s into fully typed objects using information in a
55 * [[SessionCatalog]] and a [[FunctionRegistry]].
56 */
57class Analyzer(
58    catalog: SessionCatalog,
59    conf: CatalystConf,
60    maxIterations: Int)
61  extends RuleExecutor[LogicalPlan] with CheckAnalysis {
62
63  def this(catalog: SessionCatalog, conf: CatalystConf) = {
64    this(catalog, conf, conf.optimizerMaxIterations)
65  }
66
67  def resolver: Resolver = conf.resolver
68
69  protected val fixedPoint = FixedPoint(maxIterations)
70
71  /**
72   * Override to provide additional rules for the "Resolution" batch.
73   */
74  val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
75
76  lazy val batches: Seq[Batch] = Seq(
77    Batch("Substitution", fixedPoint,
78      CTESubstitution,
79      WindowsSubstitution,
80      EliminateUnions,
81      new SubstituteUnresolvedOrdinals(conf)),
82    Batch("Resolution", fixedPoint,
83      ResolveTableValuedFunctions ::
84      ResolveRelations ::
85      ResolveReferences ::
86      ResolveCreateNamedStruct ::
87      ResolveDeserializer ::
88      ResolveNewInstance ::
89      ResolveUpCast ::
90      ResolveGroupingAnalytics ::
91      ResolvePivot ::
92      ResolveOrdinalInOrderByAndGroupBy ::
93      ResolveMissingReferences ::
94      ExtractGenerator ::
95      ResolveGenerate ::
96      ResolveFunctions ::
97      ResolveAliases ::
98      ResolveSubquery ::
99      ResolveWindowOrder ::
100      ResolveWindowFrame ::
101      ResolveNaturalAndUsingJoin ::
102      ExtractWindowExpressions ::
103      GlobalAggregates ::
104      ResolveAggregateFunctions ::
105      TimeWindowing ::
106      ResolveInlineTables ::
107      TypeCoercion.typeCoercionRules ++
108      extendedResolutionRules : _*),
109    Batch("Nondeterministic", Once,
110      PullOutNondeterministic),
111    Batch("UDF", Once,
112      HandleNullInputsForUDF),
113    Batch("FixNullability", Once,
114      FixNullability),
115    Batch("Cleanup", fixedPoint,
116      CleanupAliases)
117  )
118
119  /**
120   * Analyze cte definitions and substitute child plan with analyzed cte definitions.
121   */
122  object CTESubstitution extends Rule[LogicalPlan] {
123    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators  {
124      case With(child, relations) =>
125        substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
126          case (resolved, (name, relation)) =>
127            resolved :+ name -> execute(substituteCTE(relation, resolved))
128        })
129      case other => other
130    }
131
132    def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = {
133      plan transformDown {
134        case u : UnresolvedRelation =>
135          val substituted = cteRelations.find(x => resolver(x._1, u.tableIdentifier.table))
136            .map(_._2).map { relation =>
137              val withAlias = u.alias.map(SubqueryAlias(_, relation, None))
138              withAlias.getOrElse(relation)
139            }
140          substituted.getOrElse(u)
141        case other =>
142          // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
143          other transformExpressions {
144            case e: SubqueryExpression =>
145              e.withNewPlan(substituteCTE(e.plan, cteRelations))
146          }
147      }
148    }
149  }
150
151  /**
152   * Substitute child plan with WindowSpecDefinitions.
153   */
154  object WindowsSubstitution extends Rule[LogicalPlan] {
155    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
156      // Lookup WindowSpecDefinitions. This rule works with unresolved children.
157      case WithWindowDefinition(windowDefinitions, child) =>
158        child.transform {
159          case p => p.transformExpressions {
160            case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
161              val errorMessage =
162                s"Window specification $windowName is not defined in the WINDOW clause."
163              val windowSpecDefinition =
164                windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage))
165              WindowExpression(c, windowSpecDefinition)
166          }
167        }
168    }
169  }
170
171  /**
172   * Replaces [[UnresolvedAlias]]s with concrete aliases.
173   */
174  object ResolveAliases extends Rule[LogicalPlan] {
175    private def assignAliases(exprs: Seq[NamedExpression]) = {
176      exprs.zipWithIndex.map {
177        case (expr, i) =>
178          expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) =>
179            child match {
180              case ne: NamedExpression => ne
181              case e if !e.resolved => u
182              case g: Generator => MultiAlias(g, Nil)
183              case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
184              case e: ExtractValue => Alias(e, toPrettySQL(e))()
185              case e if optGenAliasFunc.isDefined =>
186                Alias(child, optGenAliasFunc.get.apply(e))()
187              case e => Alias(e, toPrettySQL(e))()
188            }
189          }
190      }.asInstanceOf[Seq[NamedExpression]]
191    }
192
193    private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
194      exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
195
196    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
197      case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
198        Aggregate(groups, assignAliases(aggs), child)
199
200      case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
201        g.copy(aggregations = assignAliases(g.aggregations))
202
203      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
204        if child.resolved && hasUnresolvedAlias(groupByExprs) =>
205        Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
206
207      case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
208        Project(assignAliases(projectList), child)
209    }
210  }
211
212  object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
213    /*
214     *  GROUP BY a, b, c WITH ROLLUP
215     *  is equivalent to
216     *  GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( ) ).
217     *  Group Count: N + 1 (N is the number of group expressions)
218     *
219     *  We need to get all of its subsets for the rule described above, the subset is
220     *  represented as the bit masks.
221     */
222    def bitmasks(r: Rollup): Seq[Int] = {
223      Seq.tabulate(r.groupByExprs.length + 1)(idx => (1 << idx) - 1)
224    }
225
226    /*
227     *  GROUP BY a, b, c WITH CUBE
228     *  is equivalent to
229     *  GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ).
230     *  Group Count: 2 ^ N (N is the number of group expressions)
231     *
232     *  We need to get all of its subsets for a given GROUPBY expression, the subsets are
233     *  represented as the bit masks.
234     */
235    def bitmasks(c: Cube): Seq[Int] = {
236      Seq.tabulate(1 << c.groupByExprs.length)(i => i)
237    }
238
239    private def hasGroupingAttribute(expr: Expression): Boolean = {
240      expr.collectFirst {
241        case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u
242      }.isDefined
243    }
244
245    private[analysis] def hasGroupingFunction(e: Expression): Boolean = {
246      e.collectFirst {
247        case g: Grouping => g
248        case g: GroupingID => g
249      }.isDefined
250    }
251
252    private def replaceGroupingFunc(
253        expr: Expression,
254        groupByExprs: Seq[Expression],
255        gid: Expression): Expression = {
256      expr transform {
257        case e: GroupingID =>
258          if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
259            gid
260          } else {
261            throw new AnalysisException(
262              s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
263                s"grouping columns (${groupByExprs.mkString(",")})")
264          }
265        case Grouping(col: Expression) =>
266          val idx = groupByExprs.indexOf(col)
267          if (idx >= 0) {
268            Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
269              Literal(1)), ByteType)
270          } else {
271            throw new AnalysisException(s"Column of grouping ($col) can't be found " +
272              s"in grouping columns ${groupByExprs.mkString(",")}")
273          }
274      }
275    }
276
277    // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
278    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
279      case a if !a.childrenResolved => a // be sure all of the children are resolved.
280      case p if p.expressions.exists(hasGroupingAttribute) =>
281        failAnalysis(
282          s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
283
284      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
285        GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
286      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
287        GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
288
289      // Ensure all the expressions have been resolved.
290      case x: GroupingSets if x.expressions.forall(_.resolved) =>
291        val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
292
293        // Expand works by setting grouping expressions to null as determined by the bitmasks. To
294        // prevent these null values from being used in an aggregate instead of the original value
295        // we need to create new aliases for all group by expressions that will only be used for
296        // the intended purpose.
297        val groupByAliases: Seq[Alias] = x.groupByExprs.map {
298          case e: NamedExpression => Alias(e, e.name)()
299          case other => Alias(other, other.toString)()
300        }
301
302        // The rightmost bit in the bitmasks corresponds to the last expression in groupByAliases
303        // with 0 indicating this expression is in the grouping set. The following line of code
304        // calculates the bitmask representing the expressions that absent in at least one grouping
305        // set (indicated by 1).
306        val nullBitmask = x.bitmasks.reduce(_ | _)
307
308        val attrLength = groupByAliases.length
309        val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
310          val canBeNull = ((nullBitmask >> (attrLength - idx - 1)) & 1) == 1
311          a.toAttribute.withNullability(a.nullable || canBeNull)
312        }
313
314        val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
315        val groupingAttrs = expand.output.drop(x.child.output.length)
316
317        val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
318          // collect all the found AggregateExpression, so we can check an expression is part of
319          // any AggregateExpression or not.
320          val aggsBuffer = ArrayBuffer[Expression]()
321          // Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
322          def isPartOfAggregation(e: Expression): Boolean = {
323            aggsBuffer.exists(a => a.find(_ eq e).isDefined)
324          }
325          replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
326            // AggregateExpression should be computed on the unmodified value of its argument
327            // expressions, so we should not replace any references to grouping expression
328            // inside it.
329            case e: AggregateExpression =>
330              aggsBuffer += e
331              e
332            case e if isPartOfAggregation(e) => e
333            case e =>
334              val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
335              if (index == -1) {
336                e
337              } else {
338                groupingAttrs(index)
339              }
340          }.asInstanceOf[NamedExpression]
341        }
342
343        Aggregate(groupingAttrs, aggregations, expand)
344
345      case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
346        val groupingExprs = findGroupingExprs(child)
347        // The unresolved grouping id will be resolved by ResolveMissingReferences
348        val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
349        f.copy(condition = newCond)
350
351      case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
352        val groupingExprs = findGroupingExprs(child)
353        val gid = VirtualColumn.groupingIdAttribute
354        // The unresolved grouping id will be resolved by ResolveMissingReferences
355        val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
356        s.copy(order = newOrder)
357    }
358
359    private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
360      plan.collectFirst {
361        case a: Aggregate =>
362          // this Aggregate should have grouping id as the last grouping key.
363          val gid = a.groupingExpressions.last
364          if (!gid.isInstanceOf[AttributeReference]
365            || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
366            failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
367          }
368          a.groupingExpressions.take(a.groupingExpressions.length - 1)
369      }.getOrElse {
370        failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
371      }
372    }
373  }
374
375  object ResolvePivot extends Rule[LogicalPlan] {
376    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
377      case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
378        | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
379      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
380        val singleAgg = aggregates.size == 1
381        def outputName(value: Literal, aggregate: Expression): String = {
382          if (singleAgg) {
383            value.toString
384          } else {
385            val suffix = aggregate match {
386              case n: NamedExpression => n.name
387              case _ => toPrettySQL(aggregate)
388            }
389            value + "_" + suffix
390          }
391        }
392        if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
393          // Since evaluating |pivotValues| if statements for each input row can get slow this is an
394          // alternate plan that instead uses two steps of aggregation.
395          val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
396          val namedPivotCol = pivotColumn match {
397            case n: NamedExpression => n
398            case _ => Alias(pivotColumn, "__pivot_col")()
399          }
400          val bigGroup = groupByExprs :+ namedPivotCol
401          val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
402          val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
403          val pivotAggs = namedAggExps.map { a =>
404            Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
405              .toAggregateExpression()
406            , "__pivot_" + a.sql)()
407          }
408          val groupByExprsAttr = groupByExprs.map(_.toAttribute)
409          val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg)
410          val pivotAggAttribute = pivotAggs.map(_.toAttribute)
411          val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
412            aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
413              Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
414            }
415          }
416          Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
417        } else {
418          val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
419            def ifExpr(expr: Expression) = {
420              If(EqualTo(pivotColumn, value), expr, Literal(null))
421            }
422            aggregates.map { aggregate =>
423              val filteredAggregate = aggregate.transformDown {
424                // Assumption is the aggregate function ignores nulls. This is true for all current
425                // AggregateFunction's with the exception of First and Last in their default mode
426                // (which we handle) and possibly some Hive UDAF's.
427                case First(expr, _) =>
428                  First(ifExpr(expr), Literal(true))
429                case Last(expr, _) =>
430                  Last(ifExpr(expr), Literal(true))
431                case a: AggregateFunction =>
432                  a.withNewChildren(a.children.map(ifExpr))
433              }.transform {
434                // We are duplicating aggregates that are now computing a different value for each
435                // pivot value.
436                // TODO: Don't construct the physical container until after analysis.
437                case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
438              }
439              if (filteredAggregate.fastEquals(aggregate)) {
440                throw new AnalysisException(
441                  s"Aggregate expression required for pivot, found '$aggregate'")
442              }
443              Alias(filteredAggregate, outputName(value, aggregate))()
444            }
445          }
446          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
447        }
448    }
449  }
450
451  /**
452   * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
453   */
454  object ResolveRelations extends Rule[LogicalPlan] {
455    private def lookupTableFromCatalog(u: UnresolvedRelation): LogicalPlan = {
456      try {
457        catalog.lookupRelation(u.tableIdentifier, u.alias)
458      } catch {
459        case _: NoSuchTableException =>
460          u.failAnalysis(s"Table or view not found: ${u.tableName}")
461      }
462    }
463
464    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
465      case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
466        i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
467      case u: UnresolvedRelation =>
468        val table = u.tableIdentifier
469        if (table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
470            (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))) {
471          // If the database part is specified, and we support running SQL directly on files, and
472          // it's not a temporary view, and the table does not exist, then let's just return the
473          // original UnresolvedRelation. It is possible we are matching a query like "select *
474          // from parquet.`/path/to/query`". The plan will get resolved later.
475          // Note that we are testing (!db_exists || !table_exists) because the catalog throws
476          // an exception from tableExists if the database does not exist.
477          u
478        } else {
479          lookupTableFromCatalog(u)
480        }
481    }
482  }
483
484  /**
485   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
486   * a logical plan node's children.
487   */
488  object ResolveReferences extends Rule[LogicalPlan] {
489    /**
490     * Generate a new logical plan for the right child with different expression IDs
491     * for all conflicting attributes.
492     */
493    private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
494      val conflictingAttributes = left.outputSet.intersect(right.outputSet)
495      logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
496        s"between $left and $right")
497
498      right.collect {
499        // Handle base relations that might appear more than once.
500        case oldVersion: MultiInstanceRelation
501            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
502          val newVersion = oldVersion.newInstance()
503          (oldVersion, newVersion)
504
505        case oldVersion: SerializeFromObject
506            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
507          (oldVersion, oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance())))
508
509        // Handle projects that create conflicting aliases.
510        case oldVersion @ Project(projectList, _)
511            if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
512          (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
513
514        case oldVersion @ Aggregate(_, aggregateExpressions, _)
515            if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
516          (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
517
518        case oldVersion: Generate
519            if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
520          val newOutput = oldVersion.generatorOutput.map(_.newInstance())
521          (oldVersion, oldVersion.copy(generatorOutput = newOutput))
522
523        case oldVersion @ Window(windowExpressions, _, _, child)
524            if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
525              .nonEmpty =>
526          (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
527      }
528        // Only handle first case, others will be fixed on the next pass.
529        .headOption match {
530        case None =>
531          /*
532           * No result implies that there is a logical plan node that produces new references
533           * that this rule cannot handle. When that is the case, there must be another rule
534           * that resolves these conflicts. Otherwise, the analysis will fail.
535           */
536          right
537        case Some((oldRelation, newRelation)) =>
538          val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
539          val newRight = right transformUp {
540            case r if r == oldRelation => newRelation
541          } transformUp {
542            case other => other transformExpressions {
543              case a: Attribute =>
544                attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
545            }
546          }
547          newRight
548      }
549    }
550
551    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
552      case p: LogicalPlan if !p.childrenResolved => p
553
554      // If the projection list contains Stars, expand it.
555      case p: Project if containsStar(p.projectList) =>
556        p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
557      // If the aggregate function argument contains Stars, expand it.
558      case a: Aggregate if containsStar(a.aggregateExpressions) =>
559        if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
560          failAnalysis(
561            "Star (*) is not allowed in select list when GROUP BY ordinal position is used")
562        } else {
563          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
564        }
565      // If the script transformation input contains Stars, expand it.
566      case t: ScriptTransformation if containsStar(t.input) =>
567        t.copy(
568          input = t.input.flatMap {
569            case s: Star => s.expand(t.child, resolver)
570            case o => o :: Nil
571          }
572        )
573      case g: Generate if containsStar(g.generator.children) =>
574        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
575
576      // To resolve duplicate expression IDs for Join and Intersect
577      case j @ Join(left, right, _, _) if !j.duplicateResolved =>
578        j.copy(right = dedupRight(left, right))
579      case i @ Intersect(left, right) if !i.duplicateResolved =>
580        i.copy(right = dedupRight(left, right))
581      case i @ Except(left, right) if !i.duplicateResolved =>
582        i.copy(right = dedupRight(left, right))
583
584      // When resolve `SortOrder`s in Sort based on child, don't report errors as
585      // we still have chance to resolve it based on its descendants
586      case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
587        val newOrdering =
588          ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
589        Sort(newOrdering, global, child)
590
591      // A special case for Generate, because the output of Generate should not be resolved by
592      // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
593      case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g
594
595      case g @ Generate(generator, join, outer, qualifier, output, child) =>
596        val newG = resolveExpression(generator, child, throws = true)
597        if (newG.fastEquals(generator)) {
598          g
599        } else {
600          Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
601        }
602
603      // Skips plan which contains deserializer expressions, as they should be resolved by another
604      // rule: ResolveDeserializer.
605      case plan if containsDeserializer(plan.expressions) => plan
606
607      case q: LogicalPlan =>
608        logTrace(s"Attempting to resolve ${q.simpleString}")
609        q transformExpressionsUp  {
610          case u @ UnresolvedAttribute(nameParts) =>
611            // Leave unchanged if resolution fails.  Hopefully will be resolved next round.
612            val result =
613              withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
614            logDebug(s"Resolving $u to $result")
615            result
616          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
617            ExtractValue(child, fieldExpr, resolver)
618        }
619    }
620
621    def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
622      expressions.map {
623        case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
624        case other => other
625      }
626    }
627
628    def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
629      AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
630    }
631
632    /**
633     * Build a project list for Project/Aggregate and expand the star if possible
634     */
635    private def buildExpandedProjectList(
636      exprs: Seq[NamedExpression],
637      child: LogicalPlan): Seq[NamedExpression] = {
638      exprs.flatMap {
639        // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
640        case s: Star => s.expand(child, resolver)
641        // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
642        case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
643        case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
644        case o => o :: Nil
645      }.map(_.asInstanceOf[NamedExpression])
646    }
647
648    /**
649     * Returns true if `exprs` contains a [[Star]].
650     */
651    def containsStar(exprs: Seq[Expression]): Boolean =
652      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
653
654    /**
655     * Expands the matching attribute.*'s in `child`'s output.
656     */
657    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
658      expr.transformUp {
659        case f1: UnresolvedFunction if containsStar(f1.children) =>
660          f1.copy(children = f1.children.flatMap {
661            case s: Star => s.expand(child, resolver)
662            case o => o :: Nil
663          })
664        case c: CreateNamedStruct if containsStar(c.valExprs) =>
665          val newChildren = c.children.grouped(2).flatMap {
666            case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children
667            case kv => kv
668          }
669          c.copy(children = newChildren.toList )
670        case c: CreateArray if containsStar(c.children) =>
671          c.copy(children = c.children.flatMap {
672            case s: Star => s.expand(child, resolver)
673            case o => o :: Nil
674          })
675        case p: Murmur3Hash if containsStar(p.children) =>
676          p.copy(children = p.children.flatMap {
677            case s: Star => s.expand(child, resolver)
678            case o => o :: Nil
679          })
680        // count(*) has been replaced by count(1)
681        case o if containsStar(o.children) =>
682          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
683      }
684    }
685  }
686
687  private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
688    exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
689  }
690
691  protected[sql] def resolveExpression(
692      expr: Expression,
693      plan: LogicalPlan,
694      throws: Boolean = false) = {
695    // Resolve expression in one round.
696    // If throws == false or the desired attribute doesn't exist
697    // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
698    // Else, throw exception.
699    try {
700      expr transformUp {
701        case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
702        case u @ UnresolvedAttribute(nameParts) =>
703          withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
704        case UnresolvedExtractValue(child, fieldName) if child.resolved =>
705          ExtractValue(child, fieldName, resolver)
706      }
707    } catch {
708      case a: AnalysisException if !throws => expr
709    }
710  }
711
712 /**
713  * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
714  * clauses. This rule is to convert ordinal positions to the corresponding expressions in the
715  * select list. This support is introduced in Spark 2.0.
716  *
717  * - When the sort references or group by expressions are not integer but foldable expressions,
718  * just ignore them.
719  * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position
720  * numbers too.
721  *
722  * Before the release of Spark 2.0, the literals in order/sort by and group by clauses
723  * have no effect on the results.
724  */
725  object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
726    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
727      case p if !p.childrenResolved => p
728      // Replace the index with the related attribute for ORDER BY,
729      // which is a 1-base position of the projection list.
730      case Sort(orders, global, child)
731        if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
732        val newOrders = orders map {
733          case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) =>
734            if (index > 0 && index <= child.output.size) {
735              SortOrder(child.output(index - 1), direction, nullOrdering)
736            } else {
737              s.failAnalysis(
738                s"ORDER BY position $index is not in select list " +
739                  s"(valid range is [1, ${child.output.size}])")
740            }
741          case o => o
742        }
743        Sort(newOrders, global, child)
744
745      // Replace the index with the corresponding expression in aggregateExpressions. The index is
746      // a 1-base position of aggregateExpressions, which is output columns (select expression)
747      case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
748        groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
749        val newGroups = groups.map {
750          case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
751            aggs(index - 1)
752          case ordinal @ UnresolvedOrdinal(index) =>
753            ordinal.failAnalysis(
754              s"GROUP BY position $index is not in select list " +
755                s"(valid range is [1, ${aggs.size}])")
756          case o => o
757        }
758        Aggregate(newGroups, aggs, child)
759    }
760  }
761
762  /**
763   * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
764   * clause.  This rule detects such queries and adds the required attributes to the original
765   * projection, so that they will be available during sorting. Another projection is added to
766   * remove these attributes after sorting.
767   *
768   * The HAVING clause could also used a grouping columns that is not presented in the SELECT.
769   */
770  object ResolveMissingReferences extends Rule[LogicalPlan] {
771    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
772      // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
773      case sa @ Sort(_, _, child: Aggregate) => sa
774
775      case s @ Sort(order, _, child) if child.resolved =>
776        try {
777          val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
778          val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
779          val missingAttrs = requiredAttrs -- child.outputSet
780          if (missingAttrs.nonEmpty) {
781            // Add missing attributes and then project them away after the sort.
782            Project(child.output,
783              Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
784          } else if (newOrder != order) {
785            s.copy(order = newOrder)
786          } else {
787            s
788          }
789        } catch {
790          // Attempting to resolve it might fail. When this happens, return the original plan.
791          // Users will see an AnalysisException for resolution failure of missing attributes
792          // in Sort
793          case ae: AnalysisException => s
794        }
795
796      case f @ Filter(cond, child) if child.resolved =>
797        try {
798          val newCond = resolveExpressionRecursively(cond, child)
799          val requiredAttrs = newCond.references.filter(_.resolved)
800          val missingAttrs = requiredAttrs -- child.outputSet
801          if (missingAttrs.nonEmpty) {
802            // Add missing attributes and then project them away.
803            Project(child.output,
804              Filter(newCond, addMissingAttr(child, missingAttrs)))
805          } else if (newCond != cond) {
806            f.copy(condition = newCond)
807          } else {
808            f
809          }
810        } catch {
811          // Attempting to resolve it might fail. When this happens, return the original plan.
812          // Users will see an AnalysisException for resolution failure of missing attributes
813          case ae: AnalysisException => f
814        }
815    }
816
817    /**
818     * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
819     * Aggregate.
820     */
821    private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
822      if (missingAttrs.isEmpty) {
823        return plan
824      }
825      plan match {
826        case p: Project =>
827          val missing = missingAttrs -- p.child.outputSet
828          Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
829        case a: Aggregate =>
830          // all the missing attributes should be grouping expressions
831          // TODO: push down AggregateExpression
832          missingAttrs.foreach { attr =>
833            if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
834              throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
835            }
836          }
837          val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
838          a.copy(aggregateExpressions = newAggregateExpressions)
839        case g: Generate =>
840          // If join is false, we will convert it to true for getting from the child the missing
841          // attributes that its child might have or could have.
842          val missing = missingAttrs -- g.child.outputSet
843          g.copy(join = true, child = addMissingAttr(g.child, missing))
844        case d: Distinct =>
845          throw new AnalysisException(s"Can't add $missingAttrs to $d")
846        case u: UnaryNode =>
847          u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
848        case other =>
849          throw new AnalysisException(s"Can't add $missingAttrs to $other")
850      }
851    }
852
853    /**
854     * Resolve the expression on a specified logical plan and it's child (recursively), until
855     * the expression is resolved or meet a non-unary node or Subquery.
856     */
857    @tailrec
858    private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
859      val resolved = resolveExpression(expr, plan)
860      if (resolved.resolved) {
861        resolved
862      } else {
863        plan match {
864          case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] =>
865            resolveExpressionRecursively(resolved, u.child)
866          case other => resolved
867        }
868      }
869    }
870  }
871
872  /**
873   * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
874   */
875  object ResolveFunctions extends Rule[LogicalPlan] {
876    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
877      case q: LogicalPlan =>
878        q transformExpressions {
879          case u if !u.childrenResolved => u // Skip until children are resolved.
880          case u @ UnresolvedGenerator(name, children) =>
881            withPosition(u) {
882              catalog.lookupFunction(name, children) match {
883                case generator: Generator => generator
884                case other =>
885                  failAnalysis(s"$name is expected to be a generator. However, " +
886                    s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
887              }
888            }
889          case u @ UnresolvedFunction(funcId, children, isDistinct) =>
890            withPosition(u) {
891              catalog.lookupFunction(funcId, children) match {
892                // DISTINCT is not meaningful for a Max or a Min.
893                case max: Max if isDistinct =>
894                  AggregateExpression(max, Complete, isDistinct = false)
895                case min: Min if isDistinct =>
896                  AggregateExpression(min, Complete, isDistinct = false)
897                // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
898                // the context of a Window clause. They do not need to be wrapped in an
899                // AggregateExpression.
900                case wf: AggregateWindowFunction => wf
901                // We get an aggregate function, we need to wrap it in an AggregateExpression.
902                case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
903                // This function is not an aggregate function, just return the resolved one.
904                case other => other
905              }
906            }
907        }
908    }
909  }
910
911  /**
912   * This rule resolves and rewrites subqueries inside expressions.
913   *
914   * Note: CTEs are handled in CTESubstitution.
915   */
916  object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
917    /**
918     * Resolve the correlated expressions in a subquery by using the an outer plans' references. All
919     * resolved outer references are wrapped in an [[OuterReference]]
920     */
921    private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
922      plan transformDown {
923        case q: LogicalPlan if q.childrenResolved && !q.resolved =>
924          q transformExpressions {
925            case u @ UnresolvedAttribute(nameParts) =>
926              withPosition(u) {
927                try {
928                  outer.resolve(nameParts, resolver) match {
929                    case Some(outerAttr) => OuterReference(outerAttr)
930                    case None => u
931                  }
932                } catch {
933                  case _: AnalysisException => u
934                }
935              }
936          }
937      }
938    }
939
940    /**
941     * Pull out all (outer) correlated predicates from a given subquery. This method removes the
942     * correlated predicates from subquery [[Filter]]s and adds the references of these predicates
943     * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to
944     * be able to evaluate the predicates at the top level.
945     *
946     * This method returns the rewritten subquery and correlated predicates.
947     */
948    private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
949      val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
950
951      // Make sure a plan's subtree does not contain outer references
952      def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = {
953        if (p.collectFirst(predicateMap).nonEmpty) {
954          failAnalysis(s"Accessing outer query column is not allowed in:\n$p")
955        }
956      }
957
958      // Helper function for locating outer references.
959      def containsOuter(e: Expression): Boolean = {
960        e.find(_.isInstanceOf[OuterReference]).isDefined
961      }
962
963      // Make sure a plan's expressions do not contain outer references
964      def failOnOuterReference(p: LogicalPlan): Unit = {
965        if (p.expressions.exists(containsOuter)) {
966          failAnalysis(
967            "Expressions referencing the outer query are not supported outside of WHERE/HAVING " +
968              s"clauses:\n$p")
969        }
970      }
971
972      // SPARK-17348: A potential incorrect result case.
973      // When a correlated predicate is a non-equality predicate,
974      // certain operators are not permitted from the operator
975      // hosting the correlated predicate up to the operator on the outer table.
976      // Otherwise, the pull up of the correlated predicate
977      // will generate a plan with a different semantics
978      // which could return incorrect result.
979      // Currently we check for Aggregate and Window operators
980      //
981      // Below shows an example of a Logical Plan during Analyzer phase that
982      // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..]
983      // through the Aggregate (or Window) operator could alter the result of
984      // the Aggregate.
985      //
986      // Project [c1#76]
987      // +- Project [c1#87, c2#88]
988      // :  (Aggregate or Window operator)
989      // :  +- Filter [outer(c2#77) >= c2#88)]
990      // :     +- SubqueryAlias t2, `t2`
991      // :        +- Project [_1#84 AS c1#87, _2#85 AS c2#88]
992      // :           +- LocalRelation [_1#84, _2#85]
993      // +- SubqueryAlias t1, `t1`
994      // +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
995      // +- LocalRelation [_1#73, _2#74]
996      def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
997        if (found) {
998          // Report a non-supported case as an exception
999          failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
1000        }
1001      }
1002
1003      /** Determine which correlated predicate references are missing from this plan. */
1004      def missingReferences(p: LogicalPlan): AttributeSet = {
1005        val localPredicateReferences = p.collect(predicateMap)
1006          .flatten
1007          .map(_.references)
1008          .reduceOption(_ ++ _)
1009          .getOrElse(AttributeSet.empty)
1010        localPredicateReferences -- p.outputSet
1011      }
1012
1013      var foundNonEqualCorrelatedPred : Boolean = false
1014
1015      // Simplify the predicates before pulling them out.
1016      val transformed = BooleanSimplification(sub) transformUp {
1017
1018        // Whitelist operators allowed in a correlated subquery
1019        // There are 4 categories:
1020        // 1. Operators that are allowed anywhere in a correlated subquery, and,
1021        //    by definition of the operators, they either do not contain
1022        //    any columns or cannot host outer references.
1023        // 2. Operators that are allowed anywhere in a correlated subquery
1024        //    so long as they do not host outer references.
1025        // 3. Operators that need special handlings. These operators are
1026        //    Project, Filter, Join, Aggregate, and Generate.
1027        //
1028        // Any operators that are not in the above list are allowed
1029        // in a correlated subquery only if they are not on a correlation path.
1030        // In other word, these operators are allowed only under a correlation point.
1031        //
1032        // A correlation path is defined as the sub-tree of all the operators that
1033        // are on the path from the operator hosting the correlated expressions
1034        // up to the operator producing the correlated values.
1035
1036        // Category 1:
1037        // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias
1038        case p: BroadcastHint =>
1039          p
1040        case p: Distinct =>
1041          p
1042        case p: LeafNode =>
1043          p
1044        case p: Repartition =>
1045          p
1046        case p: SubqueryAlias =>
1047          p
1048
1049        // Category 2:
1050        // These operators can be anywhere in a correlated subquery.
1051        // so long as they do not host outer references in the operators.
1052        case p: Sort =>
1053          failOnOuterReference(p)
1054          p
1055        case p: RepartitionByExpression =>
1056          failOnOuterReference(p)
1057          p
1058
1059        // Category 3:
1060        // Filter is one of the two operators allowed to host correlated expressions.
1061        // The other operator is Join. Filter can be anywhere in a correlated subquery.
1062        case f @ Filter(cond, child) =>
1063          // Find all predicates with an outer reference.
1064          val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter)
1065
1066          // Find any non-equality correlated predicates
1067          foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
1068            case _: EqualTo | _: EqualNullSafe => false
1069            case _ => true
1070          }
1071
1072          // Rewrite the filter without the correlated predicates if any.
1073          correlated match {
1074            case Nil => f
1075            case xs if local.nonEmpty =>
1076              val newFilter = Filter(local.reduce(And), child)
1077              predicateMap += newFilter -> xs
1078              newFilter
1079            case xs =>
1080              predicateMap += child -> xs
1081              child
1082          }
1083
1084        // Project cannot host any correlated expressions
1085        // but can be anywhere in a correlated subquery.
1086        case p @ Project(expressions, child) =>
1087          failOnOuterReference(p)
1088
1089          val referencesToAdd = missingReferences(p)
1090          if (referencesToAdd.nonEmpty) {
1091            Project(expressions ++ referencesToAdd, child)
1092          } else {
1093            p
1094          }
1095
1096        // Aggregate cannot host any correlated expressions
1097        // It can be on a correlation path if the correlation contains
1098        // only equality correlated predicates.
1099        // It cannot be on a correlation path if the correlation has
1100        // non-equality correlated predicates.
1101        case a @ Aggregate(grouping, expressions, child) =>
1102          failOnOuterReference(a)
1103          failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
1104
1105          val referencesToAdd = missingReferences(a)
1106          if (referencesToAdd.nonEmpty) {
1107            Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
1108          } else {
1109            a
1110          }
1111
1112        // Join can host correlated expressions.
1113        case j @ Join(left, right, joinType, _) =>
1114          joinType match {
1115            // Inner join, like Filter, can be anywhere.
1116            case _: InnerLike =>
1117              failOnOuterReference(j)
1118
1119            // Left outer join's right operand cannot be on a correlation path.
1120            // LeftAnti and ExistenceJoin are special cases of LeftOuter.
1121            // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame
1122            // so it should not show up here in Analysis phase. This is just a safety net.
1123            //
1124            // LeftSemi does not allow output from the right operand.
1125            // Any correlated references in the subplan
1126            // of the right operand cannot be pulled up.
1127            case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
1128              failOnOuterReference(j)
1129              failOnOuterReferenceInSubTree(right)
1130
1131            // Likewise, Right outer join's left operand cannot be on a correlation path.
1132            case RightOuter =>
1133              failOnOuterReference(j)
1134              failOnOuterReferenceInSubTree(left)
1135
1136            // Any other join types not explicitly listed above,
1137            // including Full outer join, are treated as Category 4.
1138            case _ =>
1139              failOnOuterReferenceInSubTree(j)
1140          }
1141          j
1142
1143        // Generator with join=true, i.e., expressed with
1144        // LATERAL VIEW [OUTER], similar to inner join,
1145        // allows to have correlation under it
1146        // but must not host any outer references.
1147        // Note:
1148        // Generator with join=false is treated as Category 4.
1149        case p @ Generate(generator, true, _, _, _, _) =>
1150          failOnOuterReference(p)
1151          p
1152
1153        // Category 4: Any other operators not in the above 3 categories
1154        // cannot be on a correlation path, that is they are allowed only
1155        // under a correlation point but they and their descendant operators
1156        // are not allowed to have any correlated expressions.
1157        case p =>
1158          failOnOuterReferenceInSubTree(p)
1159          p
1160      }
1161      (transformed, predicateMap.values.flatten.toSeq)
1162    }
1163
1164    /**
1165     * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same
1166     * attributes.
1167     */
1168    private def rewriteSubQuery(
1169        sub: LogicalPlan,
1170        outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
1171      // Pull out the tagged predicates and rewrite the subquery in the process.
1172      val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub)
1173
1174      // Make sure the inner and the outer query attributes do not collide.
1175      val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
1176      val duplicates = basePlan.outputSet.intersect(outputSet)
1177      val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
1178        val aliasMap = AttributeMap(duplicates.map { dup =>
1179          dup -> Alias(dup, dup.toString)()
1180        }.toSeq)
1181        val aliasedExpressions = basePlan.output.map { ref =>
1182          aliasMap.getOrElse(ref, ref)
1183        }
1184        val aliasedProjection = Project(aliasedExpressions, basePlan)
1185        val aliasedConditions = baseConditions.map(_.transform {
1186          case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
1187        })
1188        (aliasedProjection, aliasedConditions)
1189      } else {
1190        (basePlan, baseConditions)
1191      }
1192      // Remove outer references from the correlated predicates. We wait with extracting
1193      // these until collisions between the inner and outer query attributes have been
1194      // solved.
1195      val conditions = deDuplicatedConditions.map(_.transform {
1196        case OuterReference(ref) => ref
1197      })
1198      (plan, conditions)
1199    }
1200
1201    /**
1202     * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method
1203     * will resolve the subquery by alternating between the regular analyzer and by applying the
1204     * resolveOuterReferences rule.
1205     *
1206     * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved.
1207     */
1208    private def resolveSubQuery(
1209        e: SubqueryExpression,
1210        plans: Seq[LogicalPlan],
1211        requiredColumns: Int = 0)(
1212        f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
1213      // Step 1: Resolve the outer expressions.
1214      var previous: LogicalPlan = null
1215      var current = e.plan
1216      do {
1217        // Try to resolve the subquery plan using the regular analyzer.
1218        previous = current
1219        current = execute(current)
1220
1221        // Use the outer references to resolve the subquery plan if it isn't resolved yet.
1222        val i = plans.iterator
1223        val afterResolve = current
1224        while (!current.resolved && current.fastEquals(afterResolve) && i.hasNext) {
1225          current = resolveOuterReferences(current, i.next())
1226        }
1227      } while (!current.resolved && !current.fastEquals(previous))
1228
1229      // Step 2: Pull out the predicates if the plan is resolved.
1230      if (current.resolved) {
1231        // Make sure the resolved query has the required number of output columns. This is only
1232        // needed for Scalar and IN subqueries.
1233        if (requiredColumns > 0 && requiredColumns != current.output.size) {
1234          failAnalysis(s"The number of columns in the subquery (${current.output.size}) " +
1235            s"does not match the required number of columns ($requiredColumns)")
1236        }
1237        // Pullout predicates and construct a new plan.
1238        f.tupled(rewriteSubQuery(current, plans))
1239      } else {
1240        e.withNewPlan(current)
1241      }
1242    }
1243
1244    /**
1245     * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS
1246     * expressions into PredicateSubquery expression once the are resolved.
1247     */
1248    private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
1249      plan transformExpressions {
1250        case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
1251          resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
1252        case e @ Exists(sub, exprId) =>
1253          resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId))
1254        case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
1255          // Get the left hand side expressions.
1256          val expressions = e match {
1257            case cns : CreateNamedStruct => cns.valExprs
1258            case expr => Seq(expr)
1259          }
1260          resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) =>
1261            // Construct the IN conditions.
1262            val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled)
1263            PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId)
1264          }
1265      }
1266    }
1267
1268    /**
1269     * Resolve and rewrite all subqueries in an operator tree..
1270     */
1271    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1272      // In case of HAVING (a filter after an aggregate) we use both the aggregate and
1273      // its child for resolution.
1274      case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
1275        resolveSubQueries(f, Seq(a, a.child))
1276      // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
1277      case q: UnaryNode if q.childrenResolved =>
1278        resolveSubQueries(q, q.children)
1279    }
1280  }
1281
1282  /**
1283   * Turns projections that contain aggregate expressions into aggregations.
1284   */
1285  object GlobalAggregates extends Rule[LogicalPlan] {
1286    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1287      case Project(projectList, child) if containsAggregates(projectList) =>
1288        Aggregate(Nil, projectList, child)
1289    }
1290
1291    def containsAggregates(exprs: Seq[Expression]): Boolean = {
1292      // Collect all Windowed Aggregate Expressions.
1293      val windowedAggExprs = exprs.flatMap { expr =>
1294        expr.collect {
1295          case WindowExpression(ae: AggregateExpression, _) => ae
1296        }
1297      }.toSet
1298
1299      // Find the first Aggregate Expression that is not Windowed.
1300      exprs.exists(_.collectFirst {
1301        case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae
1302      }.isDefined)
1303    }
1304  }
1305
1306  /**
1307   * This rule finds aggregate expressions that are not in an aggregate operator.  For example,
1308   * those in a HAVING clause or ORDER BY clause.  These expressions are pushed down to the
1309   * underlying aggregate operator and then projected away after the original operator.
1310   */
1311  object ResolveAggregateFunctions extends Rule[LogicalPlan] {
1312    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1313      case filter @ Filter(havingCondition,
1314             aggregate @ Aggregate(grouping, originalAggExprs, child))
1315          if aggregate.resolved =>
1316
1317        // Try resolving the condition of the filter as though it is in the aggregate clause
1318        try {
1319          val aggregatedCondition =
1320            Aggregate(
1321              grouping,
1322              Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
1323              child)
1324          val resolvedOperator = execute(aggregatedCondition)
1325          def resolvedAggregateFilter =
1326            resolvedOperator
1327              .asInstanceOf[Aggregate]
1328              .aggregateExpressions.head
1329
1330          // If resolution was successful and we see the filter has an aggregate in it, add it to
1331          // the original aggregate operator.
1332          if (resolvedOperator.resolved) {
1333            // Try to replace all aggregate expressions in the filter by an alias.
1334            val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
1335            val transformedAggregateFilter = resolvedAggregateFilter.transform {
1336              case ae: AggregateExpression =>
1337                val alias = Alias(ae, ae.toString)()
1338                aggregateExpressions += alias
1339                alias.toAttribute
1340              // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
1341              case e: Expression if grouping.exists(_.semanticEquals(e)) &&
1342                  !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
1343                  !aggregate.output.exists(_.semanticEquals(e)) =>
1344                e match {
1345                  case ne: NamedExpression =>
1346                    aggregateExpressions += ne
1347                    ne.toAttribute
1348                  case _ =>
1349                    val alias = Alias(e, e.toString)()
1350                    aggregateExpressions += alias
1351                    alias.toAttribute
1352                }
1353            }
1354
1355            // Push the aggregate expressions into the aggregate (if any).
1356            if (aggregateExpressions.nonEmpty) {
1357              Project(aggregate.output,
1358                Filter(transformedAggregateFilter,
1359                  aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
1360            } else {
1361              filter
1362            }
1363          } else {
1364            filter
1365          }
1366        } catch {
1367          // Attempting to resolve in the aggregate can result in ambiguity.  When this happens,
1368          // just return the original plan.
1369          case ae: AnalysisException => filter
1370        }
1371
1372      case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
1373
1374        // Try resolving the ordering as though it is in the aggregate clause.
1375        try {
1376          val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
1377          val aliasedOrdering =
1378            unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true))
1379          val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
1380          val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
1381          val resolvedAliasedOrdering: Seq[Alias] =
1382            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
1383
1384          // If we pass the analysis check, then the ordering expressions should only reference to
1385          // aggregate expressions or grouping expressions, and it's safe to push them down to
1386          // Aggregate.
1387          checkAnalysis(resolvedAggregate)
1388
1389          val originalAggExprs = aggregate.aggregateExpressions.map(
1390            CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
1391
1392          // If the ordering expression is same with original aggregate expression, we don't need
1393          // to push down this ordering expression and can reference the original aggregate
1394          // expression instead.
1395          val needsPushDown = ArrayBuffer.empty[NamedExpression]
1396          val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
1397            case (evaluated, order) =>
1398              val index = originalAggExprs.indexWhere {
1399                case Alias(child, _) => child semanticEquals evaluated.child
1400                case other => other semanticEquals evaluated.child
1401              }
1402
1403              if (index == -1) {
1404                needsPushDown += evaluated
1405                order.copy(child = evaluated.toAttribute)
1406              } else {
1407                order.copy(child = originalAggExprs(index).toAttribute)
1408              }
1409          }
1410
1411          val sortOrdersMap = unresolvedSortOrders
1412            .map(new TreeNodeRef(_))
1413            .zip(evaluatedOrderings)
1414            .toMap
1415          val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))
1416
1417          // Since we don't rely on sort.resolved as the stop condition for this rule,
1418          // we need to check this and prevent applying this rule multiple times
1419          if (sortOrder == finalSortOrders) {
1420            sort
1421          } else {
1422            Project(aggregate.output,
1423              Sort(finalSortOrders, global,
1424                aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
1425          }
1426        } catch {
1427          // Attempting to resolve in the aggregate can result in ambiguity.  When this happens,
1428          // just return the original plan.
1429          case ae: AnalysisException => sort
1430        }
1431    }
1432
1433    def containsAggregate(condition: Expression): Boolean = {
1434      condition.find(_.isInstanceOf[AggregateExpression]).isDefined
1435    }
1436  }
1437
1438  /**
   * Extracts a [[Generator]] from the projectList of a [[Project]] operator and creates a
   * [[Generate]] operator under the [[Project]].
   *
   * This rule will throw [[AnalysisException]] for the following cases:
1443   * 1. [[Generator]] is nested in expressions, e.g. `SELECT explode(list) + 1 FROM tbl`
1444   * 2. more than one [[Generator]] is found in projectList,
1445   *    e.g. `SELECT explode(list), explode(list) FROM tbl`
1446   * 3. [[Generator]] is found in other operators that are not [[Project]] or [[Generate]],
1447   *    e.g. `SELECT * FROM tbl SORT BY explode(list)`
1448   */
1449  object ExtractGenerator extends Rule[LogicalPlan] {
1450    private def hasGenerator(expr: Expression): Boolean = {
1451      expr.find(_.isInstanceOf[Generator]).isDefined
1452    }
1453
1454    private def hasNestedGenerator(expr: NamedExpression): Boolean = expr match {
1455      case UnresolvedAlias(_: Generator, _) => false
1456      case Alias(_: Generator, _) => false
1457      case MultiAlias(_: Generator, _) => false
1458      case other => hasGenerator(other)
1459    }
1460
1461    private def trimAlias(expr: NamedExpression): Expression = expr match {
1462      case UnresolvedAlias(child, _) => child
1463      case Alias(child, _) => child
1464      case MultiAlias(child, _) => child
1465      case _ => expr
1466    }
1467
1468    /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
1469    private object AliasedGenerator {
1470      def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
1471        case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
1472        case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
1473        case _ => None
1474      }
1475    }
1476
1477    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1478      case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
1479        val nestedGenerator = projectList.find(hasNestedGenerator).get
1480        throw new AnalysisException("Generators are not supported when it's nested in " +
1481          "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator)))
1482
1483      case Project(projectList, _) if projectList.count(hasGenerator) > 1 =>
1484        val generators = projectList.filter(hasGenerator).map(trimAlias)
1485        throw new AnalysisException("Only one generator allowed per select clause but found " +
1486          generators.size + ": " + generators.map(toPrettySQL).mkString(", "))
1487
1488      case p @ Project(projectList, child) =>
1489        // Holds the resolved generator, if one exists in the project list.
1490        var resolvedGenerator: Generate = null
1491
1492        val newProjectList = projectList.flatMap {
1493          case AliasedGenerator(generator, names) if generator.childrenResolved =>
1494            // It's a sanity check, this should not happen as the previous case will throw
1495            // exception earlier.
1496            assert(resolvedGenerator == null, "More than one generator found in SELECT.")
1497
1498            resolvedGenerator =
1499              Generate(
1500                generator,
1501                join = projectList.size > 1, // Only join if there are other expressions in SELECT.
1502                outer = false,
1503                qualifier = None,
1504                generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
1505                child)
1506
1507            resolvedGenerator.generatorOutput
1508          case other => other :: Nil
1509        }
1510
1511        if (resolvedGenerator != null) {
1512          Project(newProjectList, resolvedGenerator)
1513        } else {
1514          p
1515        }
1516
1517      case g: Generate => g
1518
1519      case p if p.expressions.exists(hasGenerator) =>
1520        throw new AnalysisException("Generators are not supported outside the SELECT clause, but " +
1521          "got: " + p.simpleString)
1522    }
1523  }
1524
1525  /**
   * Rewrites table generating expressions that need one or more of the following in order
   * to be resolved:
   *  - concrete attribute references for their output.
   *  - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]].
1530   *
1531   * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions
1532   * that wrap the [[Generator]].
1533   */
1534  object ResolveGenerate extends Rule[LogicalPlan] {
1535    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1536      case g: Generate if !g.child.resolved || !g.generator.resolved => g
1537      case g: Generate if !g.resolved =>
1538        g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
1539    }
1540
1541    /**
1542     * Construct the output attributes for a [[Generator]], given a list of names.  If the list of
1543     * names is empty names are assigned from field names in generator.
1544     */
1545    private[analysis] def makeGeneratorOutput(
1546        generator: Generator,
1547        names: Seq[String]): Seq[Attribute] = {
1548      val elementAttrs = generator.elementSchema.toAttributes
1549
1550      if (names.length == elementAttrs.length) {
1551        names.zip(elementAttrs).map {
1552          case (name, attr) => attr.withName(name)
1553        }
1554      } else if (names.isEmpty) {
1555        elementAttrs
1556      } else {
1557        failAnalysis(
1558          "The number of aliases supplied in the AS clause does not match the number of columns " +
1559          s"output by the UDTF expected ${elementAttrs.size} aliases but got " +
1560          s"${names.mkString(",")} ")
1561      }
1562    }
1563  }
1564
1565  /**
1566   * Fixes nullability of Attributes in a resolved LogicalPlan by using the nullability of
1567   * corresponding Attributes of its children output Attributes. This step is needed because
1568   * users can use a resolved AttributeReference in the Dataset API and outer joins
   * can change the nullability of an AttributeReference. Without the fix, a nullable column's
   * nullable field can actually be set as non-nullable, which causes illegal optimizations
1571   * (e.g., NULL propagation) and wrong answers.
1572   * See SPARK-13484 and SPARK-13801 for the concrete queries of this case.
1573   */
1574  object FixNullability extends Rule[LogicalPlan] {
1575
1576    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
1577      case p if !p.resolved => p // Skip unresolved nodes.
1578      case p: LogicalPlan if p.resolved =>
1579        val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap {
1580          case (exprId, attributes) =>
1581            // If there are multiple Attributes having the same ExprId, we need to resolve
1582            // the conflict of nullable field. We do not really expect this happen.
1583            val nullable = attributes.exists(_.nullable)
1584            attributes.map(attr => attr.withNullability(nullable))
1585        }.toSeq
1586        // At here, we create an AttributeMap that only compare the exprId for the lookup
1587        // operation. So, we can find the corresponding input attribute's nullability.
1588        val attributeMap = AttributeMap[Attribute](childrenOutput.map(attr => attr -> attr))
1589        // For an Attribute used by the current LogicalPlan, if it is from its children,
1590        // we fix the nullable field by using the nullability setting of the corresponding
1591        // output Attribute from the children.
1592        p.transformExpressions {
1593          case attr: Attribute if attributeMap.contains(attr) =>
1594            attr.withNullability(attributeMap(attr).nullable)
1595        }
1596    }
1597  }
1598
1599  /**
1600   * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
1601   * aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
1602   * operators for every distinct [[WindowSpecDefinition]].
1603   *
1604   * This rule handles three cases:
1605   *  - A [[Project]] having [[WindowExpression]]s in its projectList;
1606   *  - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions.
1607   *  - A [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING
1608   *    clause and the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions.
1609   * Note: If there is a GROUP BY clause in the query, aggregations and corresponding
1610   * filters (expressions in the HAVING clause) should be evaluated before any
1611   * [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be
1612   * evaluated after all [[WindowExpression]]s.
1613   *
1614   * For every case, the transformation works as follows:
1615   * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
   *    it into two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
1617   *    all regular expressions.
1618   * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
1619   * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
1620   *    it into the plan tree.
1621   */
  object ExtractWindowExpressions extends Rule[LogicalPlan] {
    /** Returns true if any expression in `projectList` contains a [[WindowExpression]]. */
    private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
      projectList.exists(hasWindowFunction)

    /** Returns true if `expr` contains a [[WindowExpression]] anywhere in its expression tree. */
    private def hasWindowFunction(expr: NamedExpression): Boolean = {
      expr.find {
        case window: WindowExpression => true
        case _ => false
      }.isDefined
    }

    /**
     * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
     * other regular expressions that do not contain any window expression. For example, for
     * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract
     * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in
     * the window expression as attribute references. So, the first returned value will be
     * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be
     * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2].
     *
     * @return (seq of expressions containing at least one window expression,
     *          seq of non-window expressions)
     */
    private def extract(
        expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
      // First, we partition the input expressions to two part. For the first part,
      // every expression in it contain at least one WindowExpression.
      // Expressions in the second part do not have any WindowExpression.
      val (expressionsWithWindowFunctions, regularExpressions) =
        expressions.partition(hasWindowFunction)

      // Then, we need to extract those regular expressions used in the WindowExpression.
      // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
      // we need to make sure that col1 to col5 are all projected from the child of the Window
      // operator.
      val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
      // Returns the expression to use in place of `expr` inside the window expression, recording
      // in extractedExprBuffer anything that must be projected below the Window operator.
      def extractExpr(expr: Expression): Expression = expr match {
        case ne: NamedExpression =>
          // If a named expression is not in regularExpressions, add it to
          // extractedExprBuffer and replace it with an AttributeReference.
          val missingExpr =
            AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
          if (missingExpr.nonEmpty) {
            extractedExprBuffer += ne
          }
          // alias will be cleaned in the rule CleanupAliases
          ne
        case e: Expression if e.foldable =>
          e // No need to create an attribute reference if it will be evaluated as a Literal.
        case e: Expression =>
          // For other expressions, we extract it and replace it with an AttributeReference (with
          // an internal column name, e.g. "_w0").
          val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
          extractedExprBuffer += withName
          withName.toAttribute
      }

      // Now, we extract regular expressions from expressionsWithWindowFunctions
      // by using extractExpr.
      // seenWindowAggregates records the AggregateExpressions rebuilt by the windowed-aggregate
      // case below, so that the plain AggregateExpression case does not also pull them out as
      // regular expressions. NOTE(review): this relies on `transform` visiting the children of the
      // rewritten WindowExpression afterwards.
      val seenWindowAggregates = new ArrayBuffer[AggregateExpression]
      val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
        _.transform {
          // Extracts children expressions of a WindowFunction (input parameters of
          // a WindowFunction).
          case wf: WindowFunction =>
            val newChildren = wf.children.map(extractExpr)
            wf.withNewChildren(newChildren)

          // Extracts expressions from the partition spec and order spec.
          case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
            val newPartitionSpec = partitionSpec.map(extractExpr)
            val newOrderSpec = orderSpec.map { so =>
              val newChild = extractExpr(so.child)
              so.copy(child = newChild)
            }
            wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)

          // Extract Windowed AggregateExpression
          case we @ WindowExpression(
              ae @ AggregateExpression(function, _, _, _),
              spec: WindowSpecDefinition) =>
            val newChildren = function.children.map(extractExpr)
            val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
            val newAgg = ae.copy(aggregateFunction = newFunction)
            seenWindowAggregates += newAgg
            WindowExpression(newAgg, spec)

          // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
          // we need to extract SUM(x).
          case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>
            val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
            extractedExprBuffer += withName
            withName.toAttribute

          // Extracts other attributes
          case attr: Attribute => extractExpr(attr)

        }.asInstanceOf[NamedExpression]
      }

      (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer)
    } // end of extract

    /**
     * Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
     *
     * @param expressionsWithWindowFunctions expressions (as produced by `extract`) that each
     *                                       contain at least one [[WindowExpression]].
     * @param child the plan the first [[Window]] operator is placed on.
     * @return a [[Project]] over the chain of [[Window]] operators that outputs the last Window's
     *         output followed by `expressionsWithWindowFunctions`, rewritten to reference the
     *         window columns.
     */
    private def addWindow(
        expressionsWithWindowFunctions: Seq[NamedExpression],
        child: LogicalPlan): LogicalPlan = {
      // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions
      // and put those extracted WindowExpressions to extractedWindowExprBuffer.
      // This step is needed because it is possible that an expression contains multiple
      // WindowExpressions with different Window Specs.
      // After extracting WindowExpressions, we need to construct a project list to generate
      // expressionsWithWindowFunctions based on extractedWindowExprBuffer.
      // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract
      // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to
      // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)".
      // Then, the projectList will be [_we0/_we1].
      val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]()
      val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
        // We need to use transformDown because we want to trigger
        // "case alias @ Alias(window: WindowExpression, _)" first.
        _.transformDown {
          case alias @ Alias(window: WindowExpression, _) =>
            // If a WindowExpression has an assigned alias, just use it.
            extractedWindowExprBuffer += alias
            alias.toAttribute
          case window: WindowExpression =>
            // If there is no alias assigned to the WindowExpressions. We create an
            // internal column.
            val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")()
            extractedWindowExprBuffer += withName
            withName.toAttribute
        }.asInstanceOf[NamedExpression]
      }

      // Second, we group extractedWindowExprBuffer based on their Partition and Order Specs.
      // The grouping key is (partitionSpec, orderSpec) only; the frame is not part of it, so
      // window expressions that differ only in their frame end up in the same Window operator.
      val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr =>
        val distinctWindowSpec = expr.collect {
          case window: WindowExpression => window.windowSpec
        }.distinct

        // We do a final check and see if we only have a single Window Spec defined in an
        // expression.
        if (distinctWindowSpec.isEmpty) {
          failAnalysis(s"$expr does not have any WindowExpression.")
        } else if (distinctWindowSpec.length > 1) {
          // newExpressionsWithWindowFunctions only have expressions with a single
          // WindowExpression. If we reach here, we have a bug.
          failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." +
            s"Please file a bug report with this error message, stack trace, and the query.")
        } else {
          val spec = distinctWindowSpec.head
          (spec.partitionSpec, spec.orderSpec)
        }
      }.toSeq

      // Third, we aggregate them by adding each Window operator for each Window Spec and then
      // setting this to the child of the next Window operator.
      // NOTE(review): groupBy yields a Map, so the stacking order of the Window operators follows
      // the Map's iteration order rather than the order of the input expressions.
      val windowOps =
        groupedWindowExpressions.foldLeft(child) {
          case (last, ((partitionSpec, orderSpec), windowExpressions)) =>
            Window(windowExpressions, partitionSpec, orderSpec, last)
        }

      // Finally, we create a Project to output windowOps's output
      // newExpressionsWithWindowFunctions.
      Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps)
    } // end of addWindow

    // We have to use transformDown at here to make sure the rule of
    // "Aggregate with Having clause" will be triggered.
    def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {

      // Aggregate with Having clause. This rule works with an unresolved Aggregate because
      // a resolved Aggregate will not have Window Functions.
      // This case is listed before the `!p.childrenResolved` guard below: the Filter's child is
      // that (unresolved) Aggregate, so the guard would otherwise return the Filter untouched.
      case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
        if child.resolved &&
           hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
        val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
        // Create an Aggregate operator to evaluate aggregation functions.
        val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
        // Add a Filter operator for conditions in the Having clause.
        val withFilter = Filter(condition, withAggregate)
        val withWindow = addWindow(windowExpressions, withFilter)

        // Finally, generate output columns according to the original projectList.
        val finalProjectList = aggregateExprs.map(_.toAttribute)
        Project(finalProjectList, withWindow)

      // Wait until all children are resolved before extracting anything else.
      case p: LogicalPlan if !p.childrenResolved => p

      // Aggregate without Having clause.
      case a @ Aggregate(groupingExprs, aggregateExprs, child)
        if hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
        val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
        // Create an Aggregate operator to evaluate aggregation functions.
        val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
        // Add Window operators.
        val withWindow = addWindow(windowExpressions, withAggregate)

        // Finally, generate output columns according to the original projectList.
        val finalProjectList = aggregateExprs.map(_.toAttribute)
        Project(finalProjectList, withWindow)

      // We only extract Window Expressions after all expressions of the Project
      // have been resolved.
      case p @ Project(projectList, child)
        if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) =>
        val (windowExpressions, regularExpressions) = extract(projectList)
        // We add a project to get all needed expressions for window expressions from the child
        // of the original Project operator.
        val withProject = Project(regularExpressions, child)
        // Add Window operators.
        val withWindow = addWindow(windowExpressions, withProject)

        // Finally, generate output columns according to the original projectList.
        val finalProjectList = projectList.map(_.toAttribute)
        Project(finalProjectList, withWindow)
    }
  }
1846
1847  /**
   * Pulls out nondeterministic expressions from a LogicalPlan that is not a Project or Filter,
   * puts them into an inner Project and finally projects them away at the outer Project.
1850   */
1851  object PullOutNondeterministic extends Rule[LogicalPlan] {
1852    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1853      case p if !p.resolved => p // Skip unresolved nodes.
1854      case p: Project => p
1855      case f: Filter => f
1856
1857      case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
1858        val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
1859        val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
1860        a.transformExpressions { case e =>
1861          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
1862        }.copy(child = newChild)
1863
1864      // todo: It's hard to write a general rule to pull out nondeterministic expressions
1865      // from LogicalPlan, currently we only do it for UnaryNode which has same output
1866      // schema with its child.
1867      case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
1868        val nondeterToAttr = getNondeterToAttr(p.expressions)
1869        val newPlan = p.transformExpressions { case e =>
1870          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
1871        }
1872        val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child)
1873        Project(p.output, newPlan.withNewChildren(newChild :: Nil))
1874    }
1875
1876    private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
1877      exprs.filterNot(_.deterministic).flatMap { expr =>
1878        val leafNondeterministic = expr.collect { case n: Nondeterministic => n }
1879        leafNondeterministic.distinct.map { e =>
1880          val ne = e match {
1881            case n: NamedExpression => n
1882            case _ => Alias(e, "_nondeterministic")(isGenerated = true)
1883          }
1884          e -> ne
1885        }
1886      }.toMap
1887    }
1888  }
1889
1890  /**
1891   * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
1892   * null check.  When user defines a UDF with primitive parameters, there is no way to tell if the
1893   * primitive parameter is null or not, so here we assume the primitive input is null-propagatable
1894   * and we should return null if the input is null.
1895   */
1896  object HandleNullInputsForUDF extends Rule[LogicalPlan] {
1897    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1898      case p if !p.resolved => p // Skip unresolved nodes.
1899
1900      case p => p transformExpressionsUp {
1901
1902        case udf @ ScalaUDF(func, _, inputs, _, _) =>
1903          val parameterTypes = ScalaReflection.getParameterTypes(func)
1904          assert(parameterTypes.length == inputs.length)
1905
1906          val inputsNullCheck = parameterTypes.zip(inputs)
1907            // TODO: skip null handling for not-nullable primitive inputs after we can completely
1908            // trust the `nullable` information.
1909            // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
1910            .filter { case (cls, _) => cls.isPrimitive }
1911            .map { case (_, expr) => IsNull(expr) }
1912            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
1913          inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
1914      }
1915    }
1916  }
1917
1918  /**
1919   * Check and add proper window frames for all window functions.
1920   */
1921  object ResolveWindowFrame extends Rule[LogicalPlan] {
1922    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1923      case logical: LogicalPlan => logical transformExpressions {
1924        case WindowExpression(wf: WindowFunction,
1925        WindowSpecDefinition(_, _, f: SpecifiedWindowFrame))
1926          if wf.frame != UnspecifiedFrame && wf.frame != f =>
1927          failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}")
1928        case WindowExpression(wf: WindowFunction,
1929        s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
1930          if wf.frame != UnspecifiedFrame =>
1931          WindowExpression(wf, s.copy(frameSpecification = wf.frame))
1932        case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
1933          if e.resolved =>
1934          val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true)
1935          we.copy(windowSpec = s.copy(frameSpecification = frame))
1936      }
1937    }
1938  }
1939
1940  /**
   * Checks that every [[WindowFunction]] is used over an ordered window, and adds the window's
   * ordering expressions to [[RankLike]] functions.
1942   */
1943  object ResolveWindowOrder extends Rule[LogicalPlan] {
1944    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1945      case logical: LogicalPlan => logical transformExpressions {
1946        case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty =>
1947          failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " +
1948            s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " +
1949            s"ORDER BY window_ordering) from table")
1950        case WindowExpression(rank: RankLike, spec) if spec.resolved =>
1951          val order = spec.orderSpec.map(_.child)
1952          WindowExpression(rank.withOrder(order), spec)
1953      }
1954    }
1955  }
1956
1957  /**
   * Removes natural or using joins by calculating output columns based on output from two sides,
   * then applies a Project on a normal Join to eliminate the natural or using join.
1960   */
1961  object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
1962    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1963      case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
1964          if left.resolved && right.resolved && j.duplicateResolved =>
1965        commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
1966      case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
1967        // find common column names from both sides
1968        val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
1969        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition)
1970    }
1971  }
1972
1973  private def commonNaturalJoinProcessing(
1974      left: LogicalPlan,
1975      right: LogicalPlan,
1976      joinType: JoinType,
1977      joinNames: Seq[String],
1978      condition: Option[Expression]) = {
1979    val leftKeys = joinNames.map { keyName =>
1980      left.output.find(attr => resolver(attr.name, keyName)).getOrElse {
1981        throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " +
1982          s"side of the join. The left-side columns: [${left.output.map(_.name).mkString(", ")}]")
1983      }
1984    }
1985    val rightKeys = joinNames.map { keyName =>
1986      right.output.find(attr => resolver(attr.name, keyName)).getOrElse {
1987        throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " +
1988          s"side of the join. The right-side columns: [${right.output.map(_.name).mkString(", ")}]")
1989      }
1990    }
1991    val joinPairs = leftKeys.zip(rightKeys)
1992
1993    val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)
1994
1995    // columns not in joinPairs
1996    val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
1997    val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
1998
1999    // the output list looks like: join keys, columns from left, columns from right
2000    val projectList = joinType match {
2001      case LeftOuter =>
2002        leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
2003      case LeftExistence(_) =>
2004        leftKeys ++ lUniqueOutput
2005      case RightOuter =>
2006        rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
2007      case FullOuter =>
2008        // in full outer join, joinCols should be non-null if there is.
2009        val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
2010        joinedCols ++
2011          lUniqueOutput.map(_.withNullability(true)) ++
2012          rUniqueOutput.map(_.withNullability(true))
2013      case _ : InnerLike =>
2014        leftKeys ++ lUniqueOutput ++ rUniqueOutput
2015      case _ =>
2016        sys.error("Unsupported natural join type " + joinType)
2017    }
2018    // use Project to trim unnecessary fields
2019    Project(projectList, Join(left, right, joinType, newCondition))
2020  }
2021
2022  /**
2023   * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
2024   * to the given input attributes.
2025   */
2026  object ResolveDeserializer extends Rule[LogicalPlan] {
2027    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2028      case p if !p.childrenResolved => p
2029      case p if p.resolved => p
2030
2031      case p => p transformExpressions {
2032        case UnresolvedDeserializer(deserializer, inputAttributes) =>
2033          val inputs = if (inputAttributes.isEmpty) {
2034            p.children.flatMap(_.output)
2035          } else {
2036            inputAttributes
2037          }
2038
2039          validateTopLevelTupleFields(deserializer, inputs)
2040          val resolved = resolveExpression(
2041            deserializer, LocalRelation(inputs), throws = true)
2042          validateNestedTupleFields(resolved)
2043          resolved
2044      }
2045    }
2046
2047    private def fail(schema: StructType, maxOrdinal: Int): Unit = {
2048      throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " +
2049        "but failed as the number of fields does not line up.")
2050    }
2051
2052    /**
2053     * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column
2054     * by position.  However, the actual number of columns may be different from the number of Tuple
2055     * fields.  This method is used to check the number of columns and fields, and throw an
2056     * exception if they do not match.
2057     */
2058    private def validateTopLevelTupleFields(
2059        deserializer: Expression, inputs: Seq[Attribute]): Unit = {
2060      val ordinals = deserializer.collect {
2061        case GetColumnByOrdinal(ordinal, _) => ordinal
2062      }.distinct.sorted
2063
2064      if (ordinals.nonEmpty && ordinals != inputs.indices) {
2065        fail(inputs.toStructType, ordinals.last)
2066      }
2067    }
2068
2069    /**
2070     * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field
2071     * by position.  However, the actual number of struct fields may be different from the number
2072     * of nested Tuple fields.  This method is used to check the number of struct fields and nested
2073     * Tuple fields, and throw an exception if they do not match.
2074     */
2075    private def validateNestedTupleFields(deserializer: Expression): Unit = {
2076      val structChildToOrdinals = deserializer
2077        // There are 2 kinds of `GetStructField`:
2078        //   1. resolved from `UnresolvedExtractValue`, and it will have a `name` property.
2079        //   2. created when we build deserializer expression for nested tuple, no `name` property.
2080        // Here we want to validate the ordinals of nested tuple, so we should only catch
2081        // `GetStructField` without the name property.
2082        .collect { case g: GetStructField if g.name.isEmpty => g }
2083        .groupBy(_.child)
2084        .mapValues(_.map(_.ordinal).distinct.sorted)
2085
2086      structChildToOrdinals.foreach { case (expr, ordinals) =>
2087        val schema = expr.dataType.asInstanceOf[StructType]
2088        if (ordinals != schema.indices) {
2089          fail(schema, ordinals.last)
2090        }
2091      }
2092    }
2093  }
2094
2095  /**
2096   * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
2097   * constructed is an inner class.
2098   */
2099  object ResolveNewInstance extends Rule[LogicalPlan] {
2100    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2101      case p if !p.childrenResolved => p
2102      case p if p.resolved => p
2103
2104      case p => p transformExpressions {
2105        case n: NewInstance if n.childrenResolved && !n.resolved =>
2106          val outer = OuterScopes.getOuterScope(n.cls)
2107          if (outer == null) {
2108            throw new AnalysisException(
2109              s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
2110                "access to the scope that this class was defined in.\n" +
2111                "Try moving this class out of its parent class.")
2112          }
2113          n.copy(outerPointer = Some(outer))
2114      }
2115    }
2116  }
2117
2118  /**
2119   * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
2120   */
2121  object ResolveUpCast extends Rule[LogicalPlan] {
2122    private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
2123      throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
2124        s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
2125        "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
2126        "You can either add an explicit cast to the input data or choose a higher precision " +
2127        "type of the field in the target object")
2128    }
2129
2130    private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
2131      val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
2132      val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
2133      toPrecedence > 0 && fromPrecedence > toPrecedence
2134    }
2135
2136    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2137      case p if !p.childrenResolved => p
2138      case p if p.resolved => p
2139
2140      case p => p transformExpressions {
2141        case u @ UpCast(child, _, _) if !child.resolved => u
2142
2143        case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
2144          case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
2145            fail(child, to, walkedTypePath)
2146          case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
2147            fail(child, to, walkedTypePath)
2148          case (from, to) if illegalNumericPrecedence(from, to) =>
2149            fail(child, to, walkedTypePath)
2150          case (TimestampType, DateType) =>
2151            fail(child, DateType, walkedTypePath)
2152          case (StringType, to: NumericType) =>
2153            fail(child, to, walkedTypePath)
2154          case _ => Cast(child, dataType.asNullable)
2155        }
2156      }
2157    }
2158  }
2159}
2160
2161/**
2162 * Removes [[SubqueryAlias]] operators from the plan. Subqueries are only required to provide
2163 * scoping information for attributes and can be removed once analysis is complete.
2164 */
2165object EliminateSubqueryAliases extends Rule[LogicalPlan] {
2166  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
2167    case SubqueryAlias(_, child, _) => child
2168  }
2169}
2170
2171/**
2172 * Removes [[Union]] operators from the plan if it just has one child.
2173 */
2174object EliminateUnions extends Rule[LogicalPlan] {
2175  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
2176    case Union(children) if children.size == 1 => children.head
2177  }
2178}
2179
2180/**
2181 * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level
2182 * expression in Project(project list) or Aggregate(aggregate expressions) or
2183 * Window(window expressions).
2184 */
2185object CleanupAliases extends Rule[LogicalPlan] {
2186  private def trimAliases(e: Expression): Expression = {
2187    e.transformDown {
2188      case Alias(child, _) => child
2189    }
2190  }
2191
2192  def trimNonTopLevelAliases(e: Expression): Expression = e match {
2193    case a: Alias =>
2194      a.withNewChildren(trimAliases(a.child) :: Nil)
2195    case other => trimAliases(other)
2196  }
2197
2198  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2199    case Project(projectList, child) =>
2200      val cleanedProjectList =
2201        projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
2202      Project(cleanedProjectList, child)
2203
2204    case Aggregate(grouping, aggs, child) =>
2205      val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
2206      Aggregate(grouping.map(trimAliases), cleanedAggs, child)
2207
2208    case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
2209      val cleanedWindowExprs =
2210        windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
2211      Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
2212        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
2213
2214    // Operators that operate on objects should only have expressions from encoders, which should
2215    // never have extra aliases.
2216    case o: ObjectConsumer => o
2217    case o: ObjectProducer => o
2218    case a: AppendColumns => a
2219
2220    case other =>
2221      other transformExpressionsDown {
2222        case Alias(child, _) => child
2223      }
2224  }
2225}
2226
2227/**
2228 * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
2229 * figure out how many windows a time column can map to, we over-estimate the number of windows and
2230 * filter out the rows where the time column is not inside the time window.
2231 */
2232object TimeWindowing extends Rule[LogicalPlan] {
2233  import org.apache.spark.sql.catalyst.dsl.expressions._
2234
2235  private final val WINDOW_START = "start"
2236  private final val WINDOW_END = "end"
2237
2238  /**
2239   * Generates the logical plan for generating window ranges on a timestamp column. Without
2240   * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many
2241   * window ranges a timestamp will map to given all possible combinations of a window duration,
2242   * slide duration and start time (offset). Therefore, we express and over-estimate the number of
2243   * windows there may be, and filter the valid windows. We use last Project operator to group
2244   * the window columns into a struct so they can be accessed as `window.start` and `window.end`.
2245   *
2246   * The windows are calculated as below:
2247   * maxNumOverlapping <- ceil(windowDuration / slideDuration)
2248   * for (i <- 0 until maxNumOverlapping)
2249   *   windowId <- ceil((timestamp - startTime) / slideDuration)
2250   *   windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime
2251   *   windowEnd <- windowStart + windowDuration
2252   *   return windowStart, windowEnd
2253   *
2254   * This behaves as follows for the given parameters for the time: 12:05. The valid windows are
2255   * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the
2256   * Filter operator.
2257   * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m
2258   *     11:55 - 12:07 +                      11:52 - 12:04 x
2259   *     12:00 - 12:12 +                      11:57 - 12:09 +
2260   *     12:05 - 12:17 +                      12:02 - 12:14 +
2261   *
2262   * @param plan The logical plan
2263   * @return the logical plan that will generate the time windows using the Expand operator, with
2264   *         the Filter operator for correctness and Project for usability.
2265   */
2266  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
2267    case p: LogicalPlan if p.children.size == 1 =>
2268      val child = p.children.head
2269      val windowExpressions =
2270        p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct.
2271
2272      // Only support a single window expression for now
2273      if (windowExpressions.size == 1 &&
2274          windowExpressions.head.timeColumn.resolved &&
2275          windowExpressions.head.checkInputDataTypes().isSuccess) {
2276        val window = windowExpressions.head
2277
2278        val metadata = window.timeColumn match {
2279          case a: Attribute => a.metadata
2280          case _ => Metadata.empty
2281        }
2282        val windowAttr =
2283          AttributeReference("window", window.dataType, metadata = metadata)()
2284
2285        val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
2286        val windows = Seq.tabulate(maxNumOverlapping + 1) { i =>
2287          val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) /
2288            window.slideDuration)
2289          val windowStart = (windowId + i - maxNumOverlapping) *
2290              window.slideDuration + window.startTime
2291          val windowEnd = windowStart + window.windowDuration
2292
2293          CreateNamedStruct(
2294            Literal(WINDOW_START) :: windowStart ::
2295            Literal(WINDOW_END) :: windowEnd :: Nil)
2296        }
2297
2298        val projections = windows.map(_ +: p.children.head.output)
2299
2300        val filterExpr =
2301          window.timeColumn >= windowAttr.getField(WINDOW_START) &&
2302          window.timeColumn < windowAttr.getField(WINDOW_END)
2303
2304        val expandedPlan =
2305          Filter(filterExpr,
2306            Expand(projections, windowAttr +: child.output, child))
2307
2308        val substitutedPlan = p transformExpressions {
2309          case t: TimeWindow => windowAttr
2310        }
2311
2312        substitutedPlan.withNewChildren(expandedPlan :: Nil)
2313      } else if (windowExpressions.size > 1) {
2314        p.failAnalysis("Multiple time window expressions would result in a cartesian product " +
2315          "of rows, therefore they are not currently not supported.")
2316      } else {
2317        p // Return unchanged. Analyzer will throw exception later
2318      }
2319  }
2320}
2321
2322/**
2323 * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s.
2324 */
2325object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
2326  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
2327    case e: CreateNamedStruct if !e.resolved =>
2328      val children = e.children.grouped(2).flatMap {
2329        case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
2330          Seq(Literal(e.name), e)
2331        case kv =>
2332          kv
2333      }
2334      CreateNamedStruct(children.toList)
2335  }
2336}
2337