/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}

abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
  self: PlanType =>

  def output: Seq[Attribute]

  /**
   * Extracts the relevant constraints from a given set of constraints based on the attributes that
   * appear in the [[outputSet]].
   */
  protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
    constraints
      .union(inferAdditionalConstraints(constraints))
      .union(constructIsNotNullConstraints(constraints))
      .filter(constraint =>
        constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
          constraint.deterministic)
  }
  /**
   * Infers a set of `isNotNull` constraints from null intolerant expressions as well as
   * non-nullable attributes. For example, if an expression is of the form (`a > 5`), this
   * returns a constraint of the form `isNotNull(a)`.
   */
  private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
    // First, we propagate constraints from the null intolerant expressions.
    var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)

    // Second, we infer additional constraints from non-nullable attributes that are part of the
    // operator's output
    val nonNullableAttributes = output.filterNot(_.nullable)
    isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet

    isNotNullConstraints -- constraints
  }
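
  // Illustrative sketch (comment only, not executed here): for an operator whose constraints
  // contain the null-intolerant predicate `a > 5` and whose output includes a non-nullable
  // attribute `b`, and assuming `a` and `b` are resolved attribute references, this would
  // yield roughly:
  //   constructIsNotNullConstraints(Set(GreaterThan(a, Literal(5))))
  //     == Set(IsNotNull(a), IsNotNull(b))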

  /**
   * Infers the attribute-specific IsNotNull constraints from the null intolerant child
   * expressions of a constraint.
   */
  private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
    constraint match {
      // When the root is IsNotNull, we can push IsNotNull through the child null intolerant
      // expressions.
      case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_))
      // Constraints always evaluate to true for all the inputs, which means they can never
      // evaluate to null. Thus we can also push IsNotNull through the constraint's child null
      // intolerant expressions.
      case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_))
    }

  /**
   * Recursively explores the expressions which are null intolerant and returns all attributes
   * in these expressions.
   */
  private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
    case a: Attribute => Seq(a)
    case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
    case _ => Seq.empty[Attribute]
  }
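
  // For example (a sketch, assuming `a` and `b` are attributes): scanning `a + b > 5` walks
  // through the null-intolerant `>` and `+` nodes and returns Seq(a, b), while a null-tolerant
  // expression such as `Coalesce(Seq(a, b))` stops the traversal and returns Seq.empty
  // (assuming Coalesce does not mix in [[NullIntolerant]]).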

  // Collect aliases from expressions, so we may avoid producing recursive constraints.
  private lazy val aliasMap = AttributeMap(
    (expressions ++ children.flatMap(_.expressions)).collect {
      case a: Alias => (a.toAttribute, a.child)
    })

  /**
   * Infers an additional set of constraints from a given set of equality constraints.
   * For example, if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
   * additional constraint of the form `b = 5`.
   *
   * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
   * as they are often useless and can lead to a non-converging set of constraints.
   */
  private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
    val constraintClasses = generateEquivalentConstraintClasses(constraints)

    var inferredConstraints = Set.empty[Expression]
    constraints.foreach {
      case eq @ EqualTo(l: Attribute, r: Attribute) =>
        val candidateConstraints = constraints - eq
        inferredConstraints ++= candidateConstraints.map(_ transform {
          case a: Attribute if a.semanticEquals(l) &&
            !isRecursiveDeduction(r, constraintClasses) => r
        })
        inferredConstraints ++= candidateConstraints.map(_ transform {
          case a: Attribute if a.semanticEquals(r) &&
            !isRecursiveDeduction(l, constraintClasses) => l
        })
      case _ => // No inference
    }
    inferredConstraints -- constraints
  }
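
  // Illustrative sketch (comment only): with resolved attributes `a` and `b`, substituting
  // through the attribute equality reproduces the doc comment's example:
  //   inferAdditionalConstraints(Set(EqualTo(a, Literal(5)), EqualTo(a, b)))
  //     == Set(EqualTo(b, Literal(5)))
  // The literal equality `a = 5` itself triggers no inference because its right side is not an
  // [[Attribute]].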

  /*
   * Generate a sequence of expression sets from constraints, where each set stores an equivalence
   * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
   * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search for all expressions
   * that are equal to a selected attribute.
   */
  private def generateEquivalentConstraintClasses(
      constraints: Set[Expression]): Seq[Set[Expression]] = {
    var constraintClasses = Seq.empty[Set[Expression]]
    constraints.foreach {
      case eq @ EqualTo(l: Attribute, r: Attribute) =>
        // Transform [[Alias]] to its child.
        val left = aliasMap.getOrElse(l, l)
        val right = aliasMap.getOrElse(r, r)
        // Get the expression set for an equivalence constraint class.
        val leftConstraintClass = getConstraintClass(left, constraintClasses)
        val rightConstraintClass = getConstraintClass(right, constraintClasses)
        if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
          // Combine the two sets.
          constraintClasses = constraintClasses
            .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
            (leftConstraintClass ++ rightConstraintClass)
        } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
          // Update equivalence class of `left` expression.
          constraintClasses = constraintClasses
            .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
        } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
          // Update equivalence class of `right` expression.
          constraintClasses = constraintClasses
            .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
        } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
          // Create a new equivalence constraint class since neither expression is present
          // in any existing class.
          constraintClasses = constraintClasses :+ Set(left, right)
        }
      case _ => // Skip
    }

    constraintClasses
  }

  /*
   * Get all expressions equivalent to the selected expression.
   */
  private def getConstraintClass(
      expr: Expression,
      constraintClasses: Seq[Set[Expression]]): Set[Expression] =
    constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])

  /*
   *  Check whether replacing an expression with the [[Attribute]] `attr` would cause a recursive
   *  deduction. Such a deduction generally has the form `a -> f(a, b)`, where `a` and `b` are
   *  expressions and `f` is a function. Here we first get all expressions equal to `attr` and
   *  then check whether at least one of them is a child of the referenced expression.
   */
  private def isRecursiveDeduction(
      attr: Attribute,
      constraintClasses: Seq[Set[Expression]]): Boolean = {
    val expr = aliasMap.getOrElse(attr, attr)
    getConstraintClass(expr, constraintClasses).exists { e =>
      expr.children.exists(_.semanticEquals(e))
    }
  }
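
  // Sketch of the [SPARK-17733] case this guards against (hypothetical plan, comment only): if
  // `c` is an alias of `a + b` and the constraints contain `a = c`, substituting `a` with `c`
  // elsewhere would produce constraints of the form `c = f(c, b)`. Here aliasMap resolves `c`
  // to `a + b`, whose equivalence class already contains its child `a`, so
  // isRecursiveDeduction(c, ...) returns true and the substitution is skipped.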

  /**
   * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
   * example, if this set contains the expression `a = 2` then that expression is guaranteed to
   * evaluate to `true` for all rows produced.
   */
  lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
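
  // Illustrative sketch (comment only, names hypothetical): for a plan such as
  //   Filter(And(EqualTo(a, Literal(2)), GreaterThan(b, Literal(0))), relation)
  // `constraints` would typically contain `a = 2`, `b > 0`, plus the inferred `isnotnull(a)`
  // and `isnotnull(b)`, all restricted to attributes in [[outputSet]].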

  /**
   * This method can be overridden by any child class of QueryPlan to specify a set of constraints
   * based on the given operator's constraint propagation logic. These constraints are then
   * canonicalized and filtered automatically to contain only those attributes that appear in the
   * [[outputSet]].
   *
   * See [[Canonicalize]] for more details.
   */
  protected def validConstraints: Set[Expression] = Set.empty

  /**
   * Returns the set of attributes that are output by this node.
   */
  def outputSet: AttributeSet = AttributeSet(output)

  /**
   * All Attributes that appear in expressions from this operator.  Note that this set does not
   * include attributes that are implicitly referenced by being passed through to the output tuple.
   */
  def references: AttributeSet = AttributeSet(expressions.flatMap(_.references))

  /**
   * The set of all attributes that are input to this operator by its children.
   */
  def inputSet: AttributeSet =
    AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))

  /**
   * The set of all attributes that are produced by this node.
   */
  def producedAttributes: AttributeSet = AttributeSet.empty

  /**
   * Attributes that are referenced by expressions but not provided by this node's children.
   * Subclasses should override this method if they produce attributes internally, as it is used
   * by assertions designed to prevent the construction of invalid plans.
   */
  def missingInput: AttributeSet = references -- inputSet -- producedAttributes
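
  // Sketch (comment only): if this node's expressions reference an attribute `x` that no child
  // outputs and that the node does not produce itself, then missingInput == AttributeSet(x) and
  // the plan line gains the "!" prefix from [[statePrefix]] (when the node has children).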

  /**
   * Runs [[transform]] with `rule` on all expressions present in this query operator.
   * Users should not expect a specific directionality. If a specific directionality is needed,
   * transformExpressionsDown or transformExpressionsUp should be used.
   *
   * @param rule the rule to be applied to every expression in this operator.
   */
  def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
    transformExpressionsDown(rule)
  }

  /**
   * Runs [[transformDown]] with `rule` on all expressions present in this query operator.
   *
   * @param rule the rule to be applied to every expression in this operator.
   */
  def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
    mapExpressions(_.transformDown(rule))
  }

  /**
   * Runs [[transformUp]] with `rule` on all expressions present in this query operator.
   *
   * @param rule the rule to be applied to every expression in this operator.
   */
  def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = {
    mapExpressions(_.transformUp(rule))
  }
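
  // Usage sketch (comment only, names hypothetical): a simplification rule could be applied to
  // every expression in this operator with
  //   plan.transformExpressionsUp {
  //     case Add(e, Literal(0, IntegerType)) => e
  //   }
  // which rewrites bottom-up and returns the same plan instance when nothing changes.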

  /**
   * Apply a map function to each expression present in this query operator, and return a new
   * query operator based on the mapped expressions.
   */
  def mapExpressions(f: Expression => Expression): this.type = {
    var changed = false

    @inline def transformExpression(e: Expression): Expression = {
      val newE = f(e)
      if (newE.fastEquals(e)) {
        e
      } else {
        changed = true
        newE
      }
    }

    def recursiveTransform(arg: Any): AnyRef = arg match {
      case e: Expression => transformExpression(e)
      case Some(e: Expression) => Some(transformExpression(e))
      case m: Map[_, _] => m
      case d: DataType => d // Avoid unpacking Structs
      case seq: Traversable[_] => seq.map(recursiveTransform)
      case other: AnyRef => other
      case null => null
    }

    val newArgs = mapProductIterator(recursiveTransform)

    if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
  }

  /**
   * Returns the result of running [[transformExpressions]] on this node
   * and all its children.
   */
  def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
    transform {
      case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType]
    }.asInstanceOf[this.type]
  }

  /** Returns all of the expressions present in this query plan operator. */
  final def expressions: Seq[Expression] = {
    // Recursively find all expressions from a traversable.
    def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap {
      case e: Expression => e :: Nil
      case s: Traversable[_] => seqToExpressions(s)
      case other => Nil
    }

    productIterator.flatMap {
      case e: Expression => e :: Nil
      case Some(e: Expression) => e :: Nil
      case seq: Traversable[_] => seqToExpressions(seq)
      case other => Nil
    }.toSeq
  }

  lazy val schema: StructType = StructType.fromAttributes(output)

  /** Returns the output schema in the tree format. */
  def schemaString: String = schema.treeString

  /** Prints out the schema in the tree format. */
  // scalastyle:off println
  def printSchema(): Unit = println(schemaString)
  // scalastyle:on println

  /**
   * A prefix string used when printing the plan.
   *
   * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan.
   */
  protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else ""

  override def simpleString: String = statePrefix + super.simpleString

  override def verboseString: String = simpleString

  /**
   * All the subqueries of the current plan.
   */
  def subqueries: Seq[PlanType] = {
    expressions.flatMap(_.collect {
      case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType]
    })
  }

  override protected def innerChildren: Seq[QueryPlan[_]] = subqueries

  /**
   * Canonicalized copy of this query plan.
   */
  protected lazy val canonicalized: PlanType = this

  /**
   * Returns true when the given query plan will return the same results as this query plan.
   *
   * Since it is likely undecidable to generally determine whether two given plans will produce
   * the same results, it is okay for this function to return false, even if the results are
   * actually the same.  Such behavior will not affect correctness, only the application of
   * performance enhancements like caching.  However, it is not acceptable to return true if the
   * results could possibly be different.
   *
   * By default this function performs a modified version of equality that is tolerant of
   * cosmetic differences like attribute naming and/or expression id differences. Operators that
   * can do better should override this function.
   */
  def sameResult(plan: PlanType): Boolean = {
    val left = this.canonicalized
    val right = plan.canonicalized
    left.getClass == right.getClass &&
      left.children.size == right.children.size &&
      left.cleanArgs == right.cleanArgs &&
      (left.children, right.children).zipped.forall(_ sameResult _)
  }
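
  // Illustrative sketch (comment only, names hypothetical): two plans that differ only in
  // cosmetic details can compare equal here, e.g. Project(Seq(a1), r) sameResult
  // Project(Seq(a2), r) may be true when `a1` and `a2` refer to the same column under different
  // exprIds, because cleanArgs normalizes expression ids before comparison; plans with genuinely
  // different expressions will not compare equal.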

  /**
   * All the attributes that are used for this plan.
   */
  lazy val allAttributes: AttributeSeq = children.flatMap(_.output)

  protected def cleanExpression(e: Expression): Expression = e match {
    case a: Alias =>
      // As the root of the expression, an Alias will always take an arbitrary exprId, so we
      // need to erase that for equality testing.
      val cleanedExprId =
        Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated)
      BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true)
    case other =>
      BindReferences.bindReference(other, allAttributes, allowFailures = true)
  }

  /**
   * Args that have been cleaned such that differences in expression id should not affect
   * equality.
   */
  protected lazy val cleanArgs: Seq[Any] = {
    def cleanArg(arg: Any): Any = arg match {
      // Children are checked using sameResult above.
      case tn: TreeNode[_] if containsChild(tn) => null
      case e: Expression => cleanExpression(e).canonicalized
      case other => other
    }

    mapProductIterator {
      case s: Option[_] => s.map(cleanArg)
      case s: Seq[_] => s.map(cleanArg)
      case m: Map[_, _] => m.mapValues(cleanArg)
      case other => cleanArg(other)
    }.toSeq
  }
}