/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.SparkPlan


/**
 * Pulls Python UDFs out of a logical [[Aggregate]] when they take aggregate results or grouping
 * keys as input. Such UDFs are re-evaluated in a [[Project]] placed on top of a rewritten
 * [[Aggregate]], which now only computes the aggregate-only subexpressions.
 */
object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {

  /**
   * True when `e` can only be computed inside `agg`: it is an [[AggregateExpression]], or it is
   * semantically equal to one of `agg`'s grouping expressions.
   */
  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = e match {
    case _: AggregateExpression => true
    case _ => agg.groupingExpressions.exists(_.semanticEquals(e))
  }

  /**
   * True when `expr` contains a [[PythonUDF]] whose subtree (the UDF node included) contains
   * something that [[belongAggregate]] reports as computable only inside `agg`.
   */
  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
    val udfOverAggregate = expr.find {
      case udf: PythonUDF => udf.find(belongAggregate(_, agg)).isDefined
      case _ => false
    }
    udfOverAggregate.isDefined
  }

  /**
   * Splits `agg` into two operators:
   *  - an [[Aggregate]] that computes the aggregate-only subexpressions;
   *  - a [[Project]] above it that evaluates the remaining expressions, including the Python
   *    UDFs, against the new Aggregate's output attributes.
   */
  private def extract(agg: Aggregate): LogicalPlan = {
    // Output list of the Project that sits on top of the rewritten Aggregate.
    val projectList = ArrayBuffer.empty[NamedExpression]
    // Aggregate list of the rewritten Aggregate, filled in discovery order.
    val newAggregateList = ArrayBuffer.empty[NamedExpression]

    for (aggregateExpr <- agg.aggregateExpressions) {
      if (!hasPythonUdfOverAggregate(aggregateExpr, agg)) {
        // Nothing here needs to wait for the Aggregate: compute it there and pass it through.
        newAggregateList += aggregateExpr
        projectList += aggregateExpr.toAttribute
      } else {
        // Python UDF can only be evaluated after aggregate. Each aggregate-only subexpression
        // moves into the Aggregate (aliased as "agg" when it has no name of its own) and is
        // replaced here by a reference to the attribute it produces.
        val rewritten = aggregateExpr transformDown {
          case sub if belongAggregate(sub, agg) =>
            val named = sub match {
              case n: NamedExpression => n
              case unnamed => Alias(unnamed, "agg")()
            }
            newAggregateList += named
            named.toAttribute
        }
        projectList += rewritten.asInstanceOf[NamedExpression]
      }
    }

    Project(projectList, agg.copy(aggregateExpressions = newAggregateList))
  }

  /** Rewrites, bottom-up, every [[Aggregate]] that has a Python UDF over an aggregate input. */
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case aggregate: Aggregate
        if aggregate.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, aggregate)) =>
      extract(aggregate)
  }
}
/**
 * Extracts PythonUDFs from physical operators. The query plan is rewritten so that the UDFs are
 * evaluated alone, in a batch, by a [[BatchEvalPythonExec]] inserted between the operator and the
 * child that supplies the UDFs' input.
 *
 * Only PythonUDFs that can be evaluated in Python are extracted. A UDF qualifies if either:
 *  - its single child is itself an evaluable PythonUDF (the two are chained in Python); or
 *  - none of its children contains a PythonUDF (all of its inputs can be evaluated in the JVM).
 *
 * Outer UDFs that do not qualify yet are picked up by a later recursive pass, once their inner
 * UDFs have been replaced by attributes.
 *
 * This has the limitation that the input to a Python UDF is not allowed to include attributes
 * from multiple child operators.
 */
object ExtractPythonUDFs extends Rule[SparkPlan] {

  /** True when `e` or any of its descendants is a [[PythonUDF]]. */
  private def hasPythonUDF(e: Expression): Boolean = {
    e.find(_.isInstanceOf[PythonUDF]).isDefined
  }

  /**
   * True when `e` can be handed to Python as-is. That holds when either:
   *  - its only child is a PythonUDF that is itself evaluable (chaining); or
   *  - none of its children contains a PythonUDF.
   */
  private def canEvaluateInPython(e: PythonUDF): Boolean = {
    e.children match {
      // single PythonUDF child could be chained and evaluated in Python
      case Seq(u: PythonUDF) => canEvaluateInPython(u)
      // Python UDF can't be evaluated directly in JVM
      case children => !children.exists(hasPythonUDF)
    }
  }

  /**
   * Collects the outermost PythonUDFs in `expr` accepted by [[canEvaluateInPython]]. The search
   * does not descend into a UDF once that UDF has been collected.
   */
  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
    case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf)
    case e => e.children.flatMap(collectEvaluatableUDF)
  }

  /** Applies [[extract]] to every operator of `plan`, bottom-up. */
  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
    case node: SparkPlan => extract(node)
  }

  /**
   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
   *
   * Each child that supplies all inputs of some UDFs is wrapped in a [[BatchEvalPythonExec]] that
   * evaluates exactly those UDFs and appends one result attribute per UDF. The operator's UDF
   * calls are then replaced by those attributes. The method recurses until no evaluable UDF is
   * left. If the resulting output differs from `plan.output`, the extra UDF columns are trimmed
   * with a [[execution.ProjectExec]].
   *
   * @param plan the operator whose expressions may contain PythonUDFs
   * @return `plan` unchanged when it has no extractable UDFs, otherwise the rewritten operator
   * @throws RuntimeException (via `sys.error`) when a UDF's input needs attributes from more than
   *                          one child
   */
  private def extract(plan: SparkPlan): SparkPlan = {
    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
      // ignore PythonUDFs that come from the second/third aggregate, which are not used
      .filter(udf => udf.references.subsetOf(plan.inputSet))
    if (udfs.isEmpty) {
      // If there aren't any, we are done.
      plan
    } else {
      val attributeMap = mutable.HashMap[PythonUDF, Expression]()
      // Rewrite the child that has the input required for the UDF
      val newChildren = plan.children.map { child =>
        // Pick the UDFs that can be evaluated with only the input of this child.
        val validUdfs = udfs.filter { udf =>
          udf.references.subsetOf(child.outputSet)
        }.toArray // Turn it into an array since iterators cannot be serialized in Scala 2.10
        if (validUdfs.nonEmpty) {
          // Exactly one result attribute per UDF evaluated here, typed by that same UDF.
          // Deriving these from all `udfs` has two problems:
          //  - the evaluation node's output would list more result attributes than UDFs it runs;
          //  - zipping with `validUdfs` would pair a UDF with another UDF's data type whenever
          //    `validUdfs` is not a prefix of `udfs` (multi-child operators).
          val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
            AttributeReference(s"pythonUDF$i", u.dataType)()
          }
          val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
          attributeMap ++= validUdfs.zip(resultAttrs)
          evaluation
        } else {
          child
        }
      }
      // Other cases are disallowed as they are ambiguous or would require a cartesian
      // product.
      udfs.filterNot(attributeMap.contains).foreach { udf =>
        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
      }

      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
        case p: PythonUDF if attributeMap.contains(p) =>
          attributeMap(p)
      }

      // Extract remaining python UDFs recursively. Outer UDFs whose inner UDFs were just
      // replaced by attributes may now be evaluable.
      val newPlan = extract(rewritten)
      if (newPlan.output != plan.output) {
        // Trim away the new UDF value if it was only used for filtering or something.
        execution.ProjectExec(plan.output, newPlan)
      } else {
        newPlan
      }
    }
  }
}