/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.SparkPlan


/**
 * Extracts all the Python UDFs from a logical [[Aggregate]] that depend on an aggregate
 * expression or a grouping key, and evaluates them in a [[Project]] above the aggregate.
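 *
 * For example, a plan such as the following (a sketch; names are illustrative):
 * {{{
 *   Aggregate [key], [pyUDF(max(value)) AS result]
 * }}}
 * would be rewritten so that the aggregate only computes `max(value)`, and the UDF runs in a
 * Project on top of it:
 * {{{
 *   Project [pyUDF(agg) AS result]
 *   +- Aggregate [key], [max(value) AS agg]
 * }}}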
 */
object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {

  /**
   * Returns whether the expression can only be evaluated within the aggregate, i.e. whether it
   * is an aggregate expression or matches a grouping key.
   */
  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
    e.isInstanceOf[AggregateExpression] ||
      agg.groupingExpressions.exists(_.semanticEquals(e))
  }

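  /**
   * Returns whether the expression contains a Python UDF whose inputs reference the result of
   * an aggregate expression or a grouping key, i.e. a UDF that can only run after the
   * aggregate has been computed.
   */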
  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
    expr.find {
      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
    }.isDefined
  }

  private def extract(agg: Aggregate): LogicalPlan = {
    val projList = new ArrayBuffer[NamedExpression]()
    val aggExpr = new ArrayBuffer[NamedExpression]()
    agg.aggregateExpressions.foreach { expr =>
      if (hasPythonUdfOverAggregate(expr, agg)) {
        // The Python UDF can only be evaluated after the aggregate: push the aggregate
        // expressions and grouping keys it depends on into the aggregate, and reference
        // their results by attribute instead.
        val newE = expr transformDown {
          case e: Expression if belongAggregate(e, agg) =>
            val alias = e match {
              case a: NamedExpression => a
              case _ => Alias(e, "agg")()
            }
            aggExpr += alias
            alias.toAttribute
        }
        projList += newE.asInstanceOf[NamedExpression]
      } else {
        aggExpr += expr
        projList += expr.toAttribute
      }
    }
    // At this point the aggregate itself contains no Python UDFs over aggregate expressions;
    // they are all evaluated in the Project above it.
    Project(projList, agg.copy(aggregateExpressions = aggExpr))
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
      extract(agg)
  }
}


/**
 * Extracts PythonUDFs from operators, rewriting the query plan so that the UDFs can be
 * evaluated on their own in a batch.
 *
 * Only extracts the PythonUDFs that can be evaluated in Python: either the UDF's single child
 * is itself a PythonUDF (so the chain can be evaluated together in Python), or none of its
 * children contain a PythonUDF (so all inputs can be produced by the JVM first).
 *
 * This has the limitation that the input to a Python UDF is not allowed to include attributes
 * from multiple child operators.
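 *
 * For example, a physical plan such as the following (a sketch; names are illustrative):
 * {{{
 *   Filter (pyUDF(a) > 0)
 *   +- Scan [a]
 * }}}
 * would be rewritten so that the UDF is evaluated in its own BatchEvalPythonExec operator,
 * the Filter refers to the UDF's result attribute, and a Project trims the extra column away:
 * {{{
 *   Project [a]
 *   +- Filter (pythonUDF0 > 0)
 *      +- BatchEvalPythonExec [pyUDF(a)], [a, pythonUDF0]
 *         +- Scan [a]
 * }}}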
 */
object ExtractPythonUDFs extends Rule[SparkPlan] {

  private def hasPythonUDF(e: Expression): Boolean = {
    e.find(_.isInstanceOf[PythonUDF]).isDefined
  }

  private def canEvaluateInPython(e: PythonUDF): Boolean = {
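    // For example, p1(p2(a)) can be chained and evaluated in a single Python batch, but
    // p1(p2(a) + 1) cannot: the `+` has to run in the JVM between the two UDFs, so p2 is
    // extracted first and p1 is picked up by a later recursive pass of extract().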
    e.children match {
      // A single PythonUDF child can be chained and evaluated together in Python
      case Seq(u: PythonUDF) => canEvaluateInPython(u)
      // A Python UDF among the inputs can't be evaluated directly in the JVM
      case children => !children.exists(hasPythonUDF)
    }
  }

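  /**
   * Collects the Python UDFs in the expression that are ready to be evaluated. Collection
   * stops at the topmost evaluable UDF, so a chain of UDFs is extracted as a single unit.
   */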
  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
    case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf)
    case e => e.children.flatMap(collectEvaluatableUDF)
  }

  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
    case plan: SparkPlan => extract(plan)
  }

  /**
   * Extracts all the PythonUDFs from the current operator and evaluates them before the
   * operator.
   */
  private def extract(plan: SparkPlan): SparkPlan = {
    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
      // Ignore UDFs that reference attributes outside this operator's input (e.g. UDFs left
      // over from a second or third aggregate pass); they are not used by this operator.
      .filter(udf => udf.references.subsetOf(plan.inputSet))
    if (udfs.isEmpty) {
      // If there aren't any, we are done.
      plan
    } else {
      val attributeMap = mutable.HashMap[PythonUDF, Expression]()
      // Rewrite each child that produces the input required for any of the UDFs
      val newChildren = plan.children.map { child =>
        // Pick the UDFs we are going to evaluate
        val validUdfs = udfs.filter { udf =>
          // Check to make sure that the UDF can be evaluated with only the input of this child.
          udf.references.subsetOf(child.outputSet)
        }.toArray  // Turn it into an array since iterators cannot be serialized in Scala 2.10
        if (validUdfs.nonEmpty) {
          val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
            AttributeReference(s"pythonUDF$i", u.dataType)()
          }
          val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
          attributeMap ++= validUdfs.zip(resultAttrs)
          evaluation
        } else {
          child
        }
      }
      // Other cases are disallowed as they are ambiguous or would require a cartesian
      // product.
      udfs.filterNot(attributeMap.contains).foreach { udf =>
        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
      }

      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
        case p: PythonUDF if attributeMap.contains(p) =>
          attributeMap(p)
      }

      // Extract any remaining Python UDFs recursively; UDFs that depended on the ones just
      // extracted may have become evaluable now.
      val newPlan = extract(rewritten)
      if (newPlan.output != plan.output) {
        // Trim away the new UDF value if it was only used for filtering or something.
        execution.ProjectExec(plan.output, newPlan)
      } else {
        newPlan
      }
    }
  }
}