/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import scala.language.existentials

import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataType, ObjectType, StructType}


/**
 * Physical version of `ObjectProducer`.
 */
trait ObjectProducerExec extends SparkPlan {
  // The attribute that references the single object field this operator outputs.
  protected def outputObjAttr: Attribute

  override def output: Seq[Attribute] = outputObjAttr :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)

  def outputObjectType: DataType = outputObjAttr.dataType
}

/**
 * Physical version of `ObjectConsumer`.
 */
trait ObjectConsumerExec extends UnaryExecNode {
  assert(child.output.length == 1)

  // This operator always needs all columns of its child, even the ones it doesn't reference.
  override def references: AttributeSet = child.outputSet

  def inputObjectType: DataType = child.output.head.dataType
}

/**
 * Takes the input row from the child and turns it into an object using the given deserializer
 * expression.  The output of this operator is a single-field safe row containing the
 * deserialized object.
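 *
 * As a sketch (assuming a `SparkSession` named `spark`; the exact plan shape can vary by
 * version and optimizer rules), this operator typically appears at the start of a typed
 * transformation chain:
 * {{{
 *   import spark.implicits._
 *   // ds.map is typically planned as DeserializeToObject -> MapElements -> SerializeFromObject
 *   val ds = spark.range(3).as[Long]
 *   ds.map(_ + 1).explain()
 * }}}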
 */
case class DeserializeToObjectExec(
    deserializer: Expression,
    outputObjAttr: Attribute,
    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport {

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    child.asInstanceOf[CodegenSupport].inputRDDs()
  }

  protected override def doProduce(ctx: CodegenContext): String = {
    child.asInstanceOf[CodegenSupport].produce(ctx, this)
  }

  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    val bound = ExpressionCanonicalizer.execute(
      BindReferences.bindReference(deserializer, child.output))
    ctx.currentVars = input
    val resultVars = bound.genCode(ctx) :: Nil
    consume(ctx, resultVars)
  }

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
      val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
      projection.initialize(index)
      iter.map(projection)
    }
  }
}

/**
 * Takes the input object from the child and turns it into an unsafe row using the given
 * serializer expression.  The output of its child must be a single-field row containing the
 * input object.
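 *
 * As a sketch, the serializer expressions typically come from an `ExpressionEncoder`, whose
 * `namedExpressions` yields one `NamedExpression` per output column:
 * {{{
 *   import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 *   val serializer = ExpressionEncoder[(Int, String)]().namedExpressions
 * }}}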
 */
case class SerializeFromObjectExec(
    serializer: Seq[NamedExpression],
    child: SparkPlan) extends ObjectConsumerExec with CodegenSupport {

  override def output: Seq[Attribute] = serializer.map(_.toAttribute)

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    child.asInstanceOf[CodegenSupport].inputRDDs()
  }

  protected override def doProduce(ctx: CodegenContext): String = {
    child.asInstanceOf[CodegenSupport].produce(ctx, this)
  }

  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    val bound = serializer.map { expr =>
      ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
    }
    ctx.currentVars = input
    val resultVars = bound.map(_.genCode(ctx))
    consume(ctx, resultVars)
  }

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
      val projection = UnsafeProjection.create(serializer)
      projection.initialize(index)
      iter.map(projection)
    }
  }
}

/**
 * Helper functions for physical operators that work with user-defined objects.
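 *
 * As a minimal sketch of the wrap/unwrap helpers (mirroring how the operators below use them;
 * `ObjectType(classOf[String])` is just an illustrative object column type):
 * {{{
 *   val objType = ObjectType(classOf[String])
 *   val wrap = ObjectOperator.wrapObjectToRow(objType)
 *   val unwrap = ObjectOperator.unwrapObjectFromRow(objType)
 *   unwrap(wrap("hello"))  // returns "hello"
 * }}}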
 */
object ObjectOperator {
  def deserializeRowToObject(
      deserializer: Expression,
      inputSchema: Seq[Attribute]): InternalRow => Any = {
    val proj = GenerateSafeProjection.generate(deserializer :: Nil, inputSchema)
    (i: InternalRow) => proj(i).get(0, deserializer.dataType)
  }

  def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = {
    val proj = GenerateUnsafeProjection.generate(serializer)
    val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head
    val objRow = new SpecificInternalRow(objType :: Nil)
    (o: Any) => {
      objRow(0) = o
      proj(objRow)
    }
  }

  def wrapObjectToRow(objType: DataType): Any => InternalRow = {
    val outputRow = new SpecificInternalRow(objType :: Nil)
    (o: Any) => {
      outputRow(0) = o
      outputRow
    }
  }

  def unwrapObjectFromRow(objType: DataType): InternalRow => Any = {
    (i: InternalRow) => i.get(0, objType)
  }
}

/**
 * Applies the given function to the input object iterator.
 * The output of its child must be a single-field row containing the input object.
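 *
 * As a sketch (assuming a `SparkSession` named `spark`), this is the operator that
 * `Dataset.mapPartitions` typically plans:
 * {{{
 *   import spark.implicits._
 *   spark.range(10).as[Long].mapPartitions(iter => iter.map(_ * 2))
 * }}}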
 */
case class MapPartitionsExec(
    func: Iterator[Any] => Iterator[Any],
    outputObjAttr: Attribute,
    child: SparkPlan)
  extends ObjectConsumerExec with ObjectProducerExec {

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsInternal { iter =>
      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
      func(iter.map(getObject)).map(outputObject)
    }
  }
}

/**
 * Applies the given function to each input object.
 * The output of its child must be a single-field row containing the input object.
 *
 * This operator is kind of a safe version of [[ProjectExec]]: as its output is a custom object,
 * we need to use a safe row to contain it.
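 *
 * As a sketch (assuming a `SparkSession` named `spark`), both the Scala and the Java typed
 * `map` typically end up here; `doConsume` below dispatches on whether `func` is a Java
 * [[MapFunction]] or a Scala `Any => Any`:
 * {{{
 *   import spark.implicits._
 *   spark.range(3).as[Long].map(_ + 1)  // func is a Scala function
 * }}}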
 */
case class MapElementsExec(
    func: AnyRef,
    outputObjAttr: Attribute,
    child: SparkPlan)
  extends ObjectConsumerExec with ObjectProducerExec with CodegenSupport {

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    child.asInstanceOf[CodegenSupport].inputRDDs()
  }

  protected override def doProduce(ctx: CodegenContext): String = {
    child.asInstanceOf[CodegenSupport].produce(ctx, this)
  }

  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    val (funcClass, methodName) = func match {
      case _: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
      case _ => classOf[Any => Any] -> "apply"
    }
    val funcObj = Literal.create(func, ObjectType(funcClass))
    val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)

    val bound = ExpressionCanonicalizer.execute(
      BindReferences.bindReference(callFunc, child.output))
    ctx.currentVars = input
    val resultVars = bound.genCode(ctx) :: Nil

    consume(ctx, resultVars)
  }

  override protected def doExecute(): RDD[InternalRow] = {
    val callFunc: Any => Any = func match {
      case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
      case _ => func.asInstanceOf[Any => Any]
    }

    child.execute().mapPartitionsInternal { iter =>
      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
      iter.map(row => outputObject(callFunc(getObject(row))))
    }
  }

  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

  override def outputPartitioning: Partitioning = child.outputPartitioning
}

/**
 * Applies the given function to each input row, appending the encoded result at the end of
 * the row.
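 *
 * As a sketch (assuming a `SparkSession` named `spark`), `Dataset.groupByKey` typically plans
 * this operator to append the encoded grouping key to each input row:
 * {{{
 *   import spark.implicits._
 *   spark.range(10).as[Long].groupByKey(_ % 3)  // appends the key column produced by `_ % 3`
 * }}}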
 */
case class AppendColumnsExec(
    func: Any => Any,
    deserializer: Expression,
    serializer: Seq[NamedExpression],
    child: SparkPlan) extends UnaryExecNode {

  override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)

  override def outputPartitioning: Partitioning = child.outputPartitioning

  private def newColumnSchema = serializer.map(_.toAttribute).toStructType

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsInternal { iter =>
      val getObject = ObjectOperator.deserializeRowToObject(deserializer, child.output)
      val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
      val outputObject = ObjectOperator.serializeObjectToRow(serializer)

      iter.map { row =>
        val newColumns = outputObject(func(getObject(row)))
        combiner.join(row.asInstanceOf[UnsafeRow], newColumns): InternalRow
      }
    }
  }
}

/**
 * An optimized version of [[AppendColumnsExec]] that can be executed directly on the
 * deserialized object.
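 *
 * This variant is typically planned when the child already produces the object (e.g. the
 * optimizer can collapse an `AppendColumns` over a `SerializeFromObject`), so the input object
 * is serialized once instead of being serialized and then deserialized again.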
 */
case class AppendColumnsWithObjectExec(
    func: Any => Any,
    inputSerializer: Seq[NamedExpression],
    newColumnsSerializer: Seq[NamedExpression],
    child: SparkPlan) extends ObjectConsumerExec {

  override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute)

  override def outputPartitioning: Partitioning = child.outputPartitioning

  private def inputSchema = inputSerializer.map(_.toAttribute).toStructType
  private def newColumnSchema = newColumnsSerializer.map(_.toAttribute).toStructType

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsInternal { iter =>
      val getChildObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
      val outputChildObject = ObjectOperator.serializeObjectToRow(inputSerializer)
      val outputNewColumnsObj = ObjectOperator.serializeObjectToRow(newColumnsSerializer)
      val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema)

      iter.map { row =>
        val childObj = getChildObject(row)
        val newColumns = outputNewColumnsObj(func(childObj))
        combiner.join(outputChildObject(childObj), newColumns): InternalRow
      }
    }
  }
}

/**
 * Groups the input rows together and calls the function with each group and an iterator
 * containing all elements in the group.  The result of this function is flattened before
 * being output.
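 *
 * As a sketch (assuming a `SparkSession` named `spark`), this is the operator behind the typed
 * `mapGroups`/`flatMapGroups` API:
 * {{{
 *   import spark.implicits._
 *   spark.range(10).as[Long].groupByKey(_ % 3).flatMapGroups {
 *     (key, values) => Iterator(key -> values.size)
 *   }
 * }}}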
 */
case class MapGroupsExec(
    func: (Any, Iterator[Any]) => TraversableOnce[Any],
    keyDeserializer: Expression,
    valueDeserializer: Expression,
    groupingAttributes: Seq[Attribute],
    dataAttributes: Seq[Attribute],
    outputObjAttr: Attribute,
    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(groupingAttributes) :: Nil

  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
    Seq(groupingAttributes.map(SortOrder(_, Ascending)))

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsInternal { iter =>
      val grouped = GroupedIterator(iter, groupingAttributes, child.output)

      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
      val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

      grouped.flatMap { case (key, rowIter) =>
        val result = func(
          getKey(key),
          rowIter.map(getValue))
        result.map(outputObject)
      }
    }
  }
}

/**
 * Groups the input rows together and calls the R function with each group and an iterator
 * containing all elements in the group.
 * The result of this function is flattened before being output.
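 *
 * This is the physical operator that typically backs SparkR's `gapply` (note the
 * `RRunnerModes.DATAFRAME_GAPPLY` mode used in `doExecute` below).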
 */
case class FlatMapGroupsInRExec(
    func: Array[Byte],
    packageNames: Array[Byte],
    broadcastVars: Array[Broadcast[Object]],
    inputSchema: StructType,
    outputSchema: StructType,
    keyDeserializer: Expression,
    valueDeserializer: Expression,
    groupingAttributes: Seq[Attribute],
    dataAttributes: Seq[Attribute],
    outputObjAttr: Attribute,
    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(groupingAttributes) :: Nil

  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
    Seq(groupingAttributes.map(SortOrder(_, Ascending)))

  override protected def doExecute(): RDD[InternalRow] = {
    val isSerializedRData = outputSchema == SERIALIZED_R_DATA_SCHEMA
    val serializerForR = if (isSerializedRData) {
      SerializationFormats.BYTE
    } else {
      SerializationFormats.ROW
    }

    child.execute().mapPartitionsInternal { iter =>
      val grouped = GroupedIterator(iter, groupingAttributes, child.output)
      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
      val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
      val runner = new RRunner[Array[Byte]](
        func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars,
        isDataFrame = true, colNames = inputSchema.fieldNames,
        mode = RRunnerModes.DATAFRAME_GAPPLY)

      // Convert each group's key and rows into the serialized byte format expected by the
      // R worker.
      val groupedRBytes = grouped.map { case (key, rowIter) =>
        val deserializedIter = rowIter.map(getValue)
        val newIter =
          deserializedIter.asInstanceOf[Iterator[Row]].map { row => rowToRBytes(row) }
        val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
        (newKey, newIter)
      }

      val outputIter = runner.compute(groupedRBytes, -1)
      if (!isSerializedRData) {
        val result = outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
        result.map(outputObject)
      } else {
        val result = outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
        result.map(outputObject)
      }
    }
  }
}

/**
 * Co-groups the data from the left and right children, and calls the function with each group
 * and two iterators containing all elements in the group from the left and right sides.
 * The result of this function is flattened before being output.
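 *
 * As a sketch (assuming a `SparkSession` named `spark`), this is the operator behind the typed
 * `cogroup` API:
 * {{{
 *   import spark.implicits._
 *   val left = spark.range(5).as[Long].groupByKey(_ % 2)
 *   val right = spark.range(5).as[Long].groupByKey(_ % 2)
 *   left.cogroup(right)((key, l, r) => Iterator(key -> (l.size, r.size)))
 * }}}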
 */
case class CoGroupExec(
    func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
    keyDeserializer: Expression,
    leftDeserializer: Expression,
    rightDeserializer: Expression,
    leftGroup: Seq[Attribute],
    rightGroup: Seq[Attribute],
    leftAttr: Seq[Attribute],
    rightAttr: Seq[Attribute],
    outputObjAttr: Attribute,
    left: SparkPlan,
    right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {

  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil

  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
    leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil

  override protected def doExecute(): RDD[InternalRow] = {
    left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
      val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
      val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)

      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
      val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
      val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

      new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
        case (key, leftResult, rightResult) =>
          val result = func(
            getKey(key),
            leftResult.map(getLeft),
            rightResult.map(getRight))
          result.map(outputObject)
      }
    }
  }
}