/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}

/**
 * A [[TreeNode]] that represents one operator of a query plan. On top of the tree structure it
 * exposes the operator's [[output]] attributes, the expressions it holds, the constraints that
 * hold for its output rows, and helpers for rewriting and comparing plans.
 */
abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
  self: PlanType =>

  /** The attributes produced by this operator. */
  def output: Seq[Attribute]

  /**
   * Extracts the relevant constraints from a given set of constraints based on the attributes that
   * appear in the [[outputSet]].
   *
   * The given constraints are first extended with inferred equality constraints and inferred
   * `isNotNull` constraints. Only constraints that are deterministic, reference at least one
   * attribute, and reference nothing outside the [[outputSet]] are kept.
   */
  protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
    constraints
      .union(inferAdditionalConstraints(constraints))
      .union(constructIsNotNullConstraints(constraints))
      .filter(constraint =>
        constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
          constraint.deterministic)
  }

  /**
   * Infers a set of `isNotNull` constraints from null intolerant expressions as well as
   * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
   * returns a constraint of the form `isNotNull(a)`.
   *
   * Constraints that are already part of `constraints` are not returned again.
   */
  private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
    // First, we propagate constraints from the null intolerant expressions.
    val fromNullIntolerant: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)

    // Second, we infer additional constraints from non-nullable attributes that are part of the
    // operator's output.
    val fromNonNullableOutput: Set[Expression] =
      output.filterNot(_.nullable).map(IsNotNull(_)).toSet

    (fromNullIntolerant ++ fromNonNullableOutput) -- constraints
  }

  /**
   * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
   * of constraints.
   */
  private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
    constraint match {
      // When the root is IsNotNull, we can push IsNotNull through the child null intolerant
      // expressions.
      case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_))
      // A constraint evaluates to true for every row, so it never evaluates to null. Thus we can
      // push IsNotNull through the null intolerant expressions of the constraint itself.
      case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_))
    }

  /**
   * Recursively explores the expressions which are null intolerant and returns all attributes
   * in these expressions. The search stops at any expression that is neither an [[Attribute]]
   * nor [[NullIntolerant]].
   */
  private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
    case a: Attribute => Seq(a)
    case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
    case _ => Seq.empty[Attribute]
  }

  // Collect aliases from the expressions of this operator and of its children, mapping each
  // alias attribute to the aliased child expression, so we may avoid producing recursive
  // constraints.
  private lazy val aliasMap = AttributeMap(
    (expressions ++ children.flatMap(_.expressions)).collect {
      case a: Alias => (a.toAttribute, a.child)
    })

  /**
   * Infers an additional set of constraints from a given set of equality constraints.
   * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
   * additional constraint of the form `b = 5`.
   *
   * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
   * as they are often useless and can lead to a non-converging set of constraints.
   */
  private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
    val constraintClasses = generateEquivalentConstraintClasses(constraints)

    var inferredConstraints = Set.empty[Expression]
    constraints.foreach {
      case eq @ EqualTo(l: Attribute, r: Attribute) =>
        // Rewrite every other constraint in both directions: `l` replaced by `r`, and `r`
        // replaced by `l`, skipping replacements that would be recursive deductions.
        val candidateConstraints = constraints - eq
        inferredConstraints ++= candidateConstraints.map(_ transform {
          case a: Attribute if a.semanticEquals(l) &&
            !isRecursiveDeduction(r, constraintClasses) => r
        })
        inferredConstraints ++= candidateConstraints.map(_ transform {
          case a: Attribute if a.semanticEquals(r) &&
            !isRecursiveDeduction(l, constraintClasses) => l
        })
      case _ => // No inference
    }
    // Only report constraints that were not already known.
    inferredConstraints -- constraints
  }

  /*
   * Generate a sequence of expression sets from constraints, where each set stores an equivalence
   * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
   * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
   * to a selected attribute.
   */
  private def generateEquivalentConstraintClasses(
      constraints: Set[Expression]): Seq[Set[Expression]] = {
    var constraintClasses = Seq.empty[Set[Expression]]
    constraints.foreach {
      case eq @ EqualTo(l: Attribute, r: Attribute) =>
        // Transform [[Alias]] to its child.
        val left = aliasMap.getOrElse(l, l)
        val right = aliasMap.getOrElse(r, r)
        // Get the expression set for an equivalence constraint class.
        val leftConstraintClass = getConstraintClass(left, constraintClasses)
        val rightConstraintClass = getConstraintClass(right, constraintClasses)
        if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
          // Combine the two sets.
          constraintClasses = constraintClasses
            .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
            (leftConstraintClass ++ rightConstraintClass)
        } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
          // Update equivalence class of `left` expression.
          constraintClasses = constraintClasses
            .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
        } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
          // Update equivalence class of `right` expression.
          constraintClasses = constraintClasses
            .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
        } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
          // Create a new equivalence constraint class since neither expression is present
          // in any class.
          constraintClasses = constraintClasses :+ Set(left, right)
        }
      case _ => // Skip
    }

    constraintClasses
  }

  /*
   * Get all expressions equivalent to the selected expression, or an empty set when the
   * expression does not belong to any of the given classes.
   */
  private def getConstraintClass(
      expr: Expression,
      constraintClasses: Seq[Set[Expression]]): Set[Expression] =
    constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])

  /*
   * Check whether replacing by an [[Attribute]] will cause a recursive deduction. Generally it
   * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
   * Here we first get all expressions equal to `attr` and then check whether at least one of them
   * is a child of the referenced expression.
   */
  private def isRecursiveDeduction(
      attr: Attribute,
      constraintClasses: Seq[Set[Expression]]): Boolean = {
    val expr = aliasMap.getOrElse(attr, attr)
    getConstraintClass(expr, constraintClasses).exists { e =>
      expr.children.exists(_.semanticEquals(e))
    }
  }

  /**
   * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
   * example, if this set contains the expression `a = 2` then that expression is guaranteed to
   * evaluate to `true` for all rows produced.
   */
  lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))

  /**
   * This method can be overridden by any child class of QueryPlan to specify a set of constraints
   * based on the given operator's constraint propagation logic. These constraints are then
   * canonicalized and filtered automatically to contain only those attributes that appear in the
   * [[outputSet]].
   *
   * See [[Canonicalize]] for more details.
   */
  protected def validConstraints: Set[Expression] = Set.empty

  /**
   * Returns the set of attributes that are output by this node.
   */
  def outputSet: AttributeSet = AttributeSet(output)

  /**
   * All Attributes that appear in expressions from this operator.  Note that this set does not
   * include attributes that are implicitly referenced by being passed through to the output tuple.
   */
  def references: AttributeSet = AttributeSet(expressions.flatMap(_.references))

  /**
   * The set of all attributes that are input to this operator by its children.
   */
  def inputSet: AttributeSet =
    AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))

  /**
   * The set of all attributes that are produced by this node.
   */
  def producedAttributes: AttributeSet = AttributeSet.empty

  /**
   * Attributes that are referenced by expressions but not provided by this node's children.
   * Subclasses should override this method if they produce attributes internally as it is used by
   * assertions designed to prevent the construction of invalid plans.
   */
  def missingInput: AttributeSet = references -- inputSet -- producedAttributes

  /**
   * Runs [[transform]] with `rule` on all expressions present in this query operator.
   * Users should not expect a specific directionality. If a specific directionality is needed,
   * transformExpressionsDown or transformExpressionsUp should be used.
   *
   * @param rule the rule to be applied to every expression in this operator.
   */
  def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
    transformExpressionsDown(rule)
  }

  /**
   * Runs [[transformDown]] with `rule` on all expressions present in this query operator.
   *
   * @param rule the rule to be applied to every expression in this operator.
   */
  def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
    mapExpressions(_.transformDown(rule))
  }

  /**
   * Runs [[transformUp]] with `rule` on all expressions present in this query operator.
   *
   * @param rule the rule to be applied to every expression in this operator.
   * @return this operator if no expression changed, otherwise a copy holding the rewritten
   *         expressions.
   */
  def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = {
    mapExpressions(_.transformUp(rule))
  }

  /**
   * Apply a map function to each expression present in this query operator, and return a new
   * query operator based on the mapped expressions.
   *
   * Expressions are looked up directly in the constructor arguments, inside `Some(...)`, and
   * inside (possibly nested) traversable arguments. Maps and [[DataType]]s are passed through
   * untouched. When no expression changed, `this` is returned instead of a copy.
   */
  def mapExpressions(f: Expression => Expression): this.type = {
    var changed = false

    @inline def transformExpression(e: Expression): Expression = {
      val newE = f(e)
      if (newE.fastEquals(e)) {
        e
      } else {
        changed = true
        newE
      }
    }

    def recursiveTransform(arg: Any): AnyRef = arg match {
      case e: Expression => transformExpression(e)
      case Some(e: Expression) => Some(transformExpression(e))
      case m: Map[_, _] => m
      case d: DataType => d // Avoid unpacking Structs
      // Mapping a Stream only evaluates its head eagerly. Force the whole stream so that every
      // element is transformed, and `changed` is final, before it is inspected below. Otherwise
      // a rewrite that only touches tail elements would be dropped.
      case stream: Stream[_] => stream.map(recursiveTransform).force
      case seq: Traversable[_] => seq.map(recursiveTransform)
      case other: AnyRef => other
      case null => null
    }

    val newArgs = mapProductIterator(recursiveTransform)

    if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
  }

  /**
   * Returns the result of running [[transformExpressions]] on this node
   * and all its children.
   */
  def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
    transform {
      case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType]
    }.asInstanceOf[this.type]
  }

  /** Returns all of the expressions present in this query plan operator. */
  final def expressions: Seq[Expression] = {
    // Recursively find all expressions from a traversable.
    def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap {
      case e: Expression => e :: Nil
      case s: Traversable[_] => seqToExpressions(s)
      case _ => Nil
    }

    productIterator.flatMap {
      case e: Expression => e :: Nil
      case Some(e: Expression) => e :: Nil
      case seq: Traversable[_] => seqToExpressions(seq)
      case _ => Nil
    }.toSeq
  }

  /** The schema of this operator's output, derived from [[output]]. */
  lazy val schema: StructType = StructType.fromAttributes(output)

  /** Returns the output schema in the tree format. */
  def schemaString: String = schema.treeString

  /** Prints out the schema in the tree format */
  // scalastyle:off println
  def printSchema(): Unit = println(schemaString)
  // scalastyle:on println

  /**
   * A prefix string used when printing the plan.
   *
   * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan.
   */
  protected def statePrefix: String =
    if (missingInput.nonEmpty && children.nonEmpty) "!" else ""

  override def simpleString: String = statePrefix + super.simpleString

  override def verboseString: String = simpleString

  /**
   * All the subqueries of current plan.
   */
  def subqueries: Seq[PlanType] = {
    expressions.flatMap(_.collect {
      case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType]
    })
  }

  override protected def innerChildren: Seq[QueryPlan[_]] = subqueries

  /**
   * Canonicalized copy of this query plan.
   */
  protected lazy val canonicalized: PlanType = this

  /**
   * Returns true when the given query plan will return the same results as this query plan.
   *
   * Since it's likely undecidable to generally determine if two given plans will produce the same
   * results, it is okay for this function to return false, even if the results are actually
   * the same.  Such behavior will not affect correctness, only the application of performance
   * enhancements like caching.  However, it is not acceptable to return true if the results could
   * possibly be different.
   *
   * By default this function performs a modified version of equality that is tolerant of cosmetic
   * differences like attribute naming and/or expression id differences.  Operators that
   * can do better should override this function.
   */
  def sameResult(plan: PlanType): Boolean = {
    val left = this.canonicalized
    val right = plan.canonicalized
    left.getClass == right.getClass &&
      left.children.size == right.children.size &&
      left.cleanArgs == right.cleanArgs &&
      (left.children, right.children).zipped.forall(_ sameResult _)
  }

  /**
   * All the attributes that are used for this plan.
   */
  lazy val allAttributes: AttributeSeq = children.flatMap(_.output)

  /**
   * Binds the references of `e` against [[allAttributes]] (tolerating unresolvable references).
   * A root [[Alias]] additionally has its expression id replaced by a fixed one first.
   */
  protected def cleanExpression(e: Expression): Expression = e match {
    case a: Alias =>
      // As the root of the expression, Alias will always take an arbitrary exprId, we need
      // to erase that for equality testing.
      val cleanedExprId =
        Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated)
      BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true)
    case other =>
      BindReferences.bindReference(other, allAttributes, allowFailures = true)
  }

  /** Args that have been cleaned such that differences in expression id do not affect equality */
  protected lazy val cleanArgs: Seq[Any] = {
    def cleanArg(arg: Any): Any = arg match {
      // Children are checked using sameResult above.
      case tn: TreeNode[_] if containsChild(tn) => null
      case e: Expression => cleanExpression(e).canonicalized
      case other => other
    }

    mapProductIterator {
      case s: Option[_] => s.map(cleanArg)
      case s: Seq[_] => s.map(cleanArg)
      case m: Map[_, _] => m.mapValues(cleanArg)
      case other => cleanArg(other)
    }.toSeq
  }
}