1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.plans.logical
19
20import scala.language.existentials
21
22import org.apache.spark.api.java.function.FilterFunction
23import org.apache.spark.broadcast.Broadcast
24import org.apache.spark.sql.{Encoder, Row}
25import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
26import org.apache.spark.sql.catalyst.encoders._
27import org.apache.spark.sql.catalyst.expressions._
28import org.apache.spark.sql.catalyst.expressions.objects.Invoke
29import org.apache.spark.sql.types._
30
/**
 * Helpers that build the logical operators converting between rows and domain objects, driven by
 * the implicit [[Encoder]] of the object type `T`.
 */
object CatalystSerde {
  /**
   * Wraps `child` in a [[DeserializeToObject]] whose deserializer expression is taken from the
   * encoder of `T` and wrapped in an [[UnresolvedDeserializer]].
   */
  def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject =
    DeserializeToObject(
      UnresolvedDeserializer(encoderFor[T].deserializer),
      generateObjAttr[T],
      child)

  /**
   * Wraps `child` in a [[SerializeFromObject]] that uses the named expressions of the encoder
   * of `T` as its serializer.
   */
  def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
    val serializer = encoderFor[T].namedExpressions
    SerializeFromObject(serializer, child)
  }

  /**
   * Builds the attribute, named "obj", that represents an object of type `T`. Its data type is
   * the data type of the encoder's deserializer expression.
   */
  def generateObjAttr[T : Encoder]: Attribute = {
    val encoder = encoderFor[T]
    val objType = encoder.deserializer.dataType
    // The attribute is nullable unless the encoder's runtime class is a primitive.
    val objNullable = !encoder.clsTag.runtimeClass.isPrimitive
    AttributeReference("obj", objType, objNullable)()
  }
}
48
/**
 * A trait for logical operators that produce domain objects as output.
 * The output of such an operator is a single-field safe row containing the produced object.
 */
trait ObjectProducer extends LogicalPlan {
  /** The attribute that refers to the single object field this operator outputs. */
  def outputObjAttr: Attribute

  /** The output consists of exactly one attribute: `outputObjAttr`. */
  override def output: Seq[Attribute] = List(outputObjAttr)

  /** The object attribute is reported as produced by this operator. */
  override def producedAttributes: AttributeSet = {
    AttributeSet(outputObjAttr)
  }
}
61
/**
 * A trait for logical operators that consume domain objects as input.
 * The output of the child must be a single-field row containing the input object.
 */
trait ObjectConsumer extends UnaryNode {
  // Evaluated in the trait body, i.e. while the operator is being initialized.
  assert(child.output.size == 1)

  // All columns of the child are reported as referenced, whether or not this operator refers
  // to them explicitly.
  override def references: AttributeSet = {
    child.outputSet
  }

  /** The first attribute of the child's output, which holds the input object. */
  def inputObjAttr: Attribute = child.output.head
}
74
/**
 * Takes the input row from `child` and turns it into an object using the given deserializer
 * expression.
 *
 * @param deserializer expression that turns an input row into the object.
 * @param outputObjAttr the single attribute this operator outputs.
 * @param child the plan producing the rows to deserialize.
 */
case class DeserializeToObject(
    deserializer: Expression,
    outputObjAttr: Attribute,
    child: LogicalPlan)
  extends UnaryNode with ObjectProducer
82
/**
 * Takes the input object from `child` and turns it into an unsafe row using the given serializer
 * expressions.
 *
 * @param serializer named expressions that produce the output columns.
 * @param child the plan producing the objects to serialize.
 */
case class SerializeFromObject(
    serializer: Seq[NamedExpression],
    child: LogicalPlan) extends ObjectConsumer {

  // One output attribute per serializer expression, in the same order.
  override def output: Seq[Attribute] = serializer.map(expr => expr.toAttribute)
}
93
/** Factory for constructing new `MapPartitions` nodes. */
object MapPartitions {
  /**
   * Builds the plan `serialize[U](MapPartitions(deserialize[T](child)))`: the child rows are
   * deserialized to `T`, `func` is applied, and the resulting `U` objects are serialized again.
   */
  def apply[T : Encoder, U : Encoder](
      func: Iterator[T] => Iterator[U],
      child: LogicalPlan): LogicalPlan = {
    val objectInput = CatalystSerde.deserialize[T](child)
    // The operator itself is untyped, so the function is stored with its types cast to `Any`.
    val untypedFunc = func.asInstanceOf[Iterator[Any] => Iterator[Any]]
    val objectOutput = MapPartitions(untypedFunc, CatalystSerde.generateObjAttr[U], objectInput)
    CatalystSerde.serialize[U](objectOutput)
  }
}

/**
 * A relation produced by applying `func` to each partition of the `child`.
 *
 * @param func the function applied to the iterator of input objects.
 * @param outputObjAttr the single attribute this operator outputs.
 * @param child the plan producing the input objects.
 */
case class MapPartitions(
    func: Iterator[Any] => Iterator[Any],
    outputObjAttr: Attribute,
    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
114
/** Factory for constructing new `MapPartitionsInR` nodes. */
object MapPartitionsInR {
  /**
   * Builds a plan that deserializes the rows of `child` with `encoder`, applies the
   * `MapPartitionsInR` operator, and serializes the result with a row encoder for `schema`.
   *
   * @param func the serialized R function.
   * @param packageNames passed through unchanged to the operator.
   * @param broadcastVars passed through unchanged to the operator.
   * @param schema the schema of the output rows.
   * @param encoder the encoder used to deserialize the input rows; its schema becomes the
   *                operator's `inputSchema`.
   * @param child the plan producing the input rows.
   */
  def apply(
      func: Array[Byte],
      packageNames: Array[Byte],
      broadcastVars: Array[Broadcast[Object]],
      schema: StructType,
      encoder: ExpressionEncoder[Row],
      child: LogicalPlan): LogicalPlan = {
    val deserialized = CatalystSerde.deserialize(child)(encoder)
    // Build the output row encoder once and share it between the object attribute and the
    // serializer, instead of constructing an identical encoder twice.
    val outputEncoder = RowEncoder(schema)
    val mapped = MapPartitionsInR(
      func,
      packageNames,
      broadcastVars,
      encoder.schema,
      schema,
      CatalystSerde.generateObjAttr(outputEncoder),
      deserialized)
    CatalystSerde.serialize(mapped)(outputEncoder)
  }
}

/**
 * A relation produced by applying a serialized R function `func` to each partition of the `child`.
 */
case class MapPartitionsInR(
    func: Array[Byte],
    packageNames: Array[Byte],
    broadcastVars: Array[Broadcast[Object]],
    inputSchema: StructType,
    outputSchema: StructType,
    outputObjAttr: Attribute,
    child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
  override lazy val schema = outputSchema

  // `func`, `packageNames` and `broadcastVars` are deliberately left out of `stringArgs`.
  override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
    outputObjAttr, child)
}
153
/** Factory for constructing new `MapElements` nodes. */
object MapElements {
  /**
   * Builds the plan `serialize[U](MapElements(deserialize[T](child)))`. The runtime class and
   * schema of the encoder of `T` are recorded on the operator as the argument class and schema.
   */
  def apply[T : Encoder, U : Encoder](
      func: AnyRef,
      child: LogicalPlan): LogicalPlan = {
    val objectInput = CatalystSerde.deserialize[T](child)
    val inputEncoder = implicitly[Encoder[T]]
    val objectOutput = MapElements(
      func,
      inputEncoder.clsTag.runtimeClass,
      inputEncoder.schema,
      CatalystSerde.generateObjAttr[U],
      objectInput)
    CatalystSerde.serialize[U](objectOutput)
  }
}

/**
 * A relation produced by applying `func` to each element of the `child`.
 *
 * @param func the function object applied to each input object.
 * @param argumentClass the runtime class of the input object type.
 * @param argumentSchema the schema of the input object type.
 * @param outputObjAttr the single attribute this operator outputs.
 * @param child the plan producing the input objects.
 */
case class MapElements(
    func: AnyRef,
    argumentClass: Class[_],
    argumentSchema: StructType,
    outputObjAttr: Attribute,
    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
178
/** Factory for constructing new `TypedFilter` nodes. */
object TypedFilter {
  /**
   * Builds a [[TypedFilter]] over `child`, taking the argument class, argument schema and
   * (unresolved) deserializer from the encoder of `T`.
   */
  def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
    val inputEncoder = implicitly[Encoder[T]]
    TypedFilter(
      func,
      inputEncoder.clsTag.runtimeClass,
      inputEncoder.schema,
      UnresolvedDeserializer(encoderFor[T].deserializer),
      child)
  }
}

/**
 * A relation produced by applying `func` to each element of the `child` and filtering the
 * elements by the resulting boolean value.
 *
 * This is logically equal to a normal [[Filter]] operator whose condition expression decodes
 * the input row to an object and applies the given function to the decoded object. However, the
 * encapsulation of [[TypedFilter]] makes the concept clearer and makes it easier to write
 * optimizer rules.
 */
case class TypedFilter(
    func: AnyRef,
    argumentClass: Class[_],
    argumentSchema: StructType,
    deserializer: Expression,
    child: LogicalPlan) extends UnaryNode {

  // Filtering does not change the shape of the rows.
  override def output: Seq[Attribute] = child.output

  /**
   * Turns this typed filter into a plain [[Filter]] on top of `obj`, whose output must consist
   * of exactly one attribute. That attribute is the input to the filter condition.
   */
  def withObjectProducerChild(obj: LogicalPlan): Filter = {
    val objOutput = obj.output
    assert(objOutput.length == 1)
    Filter(typedCondition(objOutput.head), obj)
  }

  /**
   * Builds the boolean expression that invokes `func` on `input`: method `call` when `func` is
   * a [[FilterFunction]], method `apply` otherwise.
   */
  def typedCondition(input: Expression): Expression = {
    val (funcClass, methodName) = func match {
      case _: FilterFunction[_] => (classOf[FilterFunction[_]], "call")
      case _ => (classOf[Any => Boolean], "apply")
    }
    Invoke(Literal.create(func, ObjectType(funcClass)), methodName, BooleanType, List(input))
  }
}
222
/** Factory for constructing new `AppendColumns` nodes. */
object AppendColumns {
  /**
   * Builds an [[AppendColumns]] node whose deserializer comes from the encoder of `T` and whose
   * serializer comes from the encoder of `U`.
   */
  def apply[T : Encoder, U : Encoder](
      func: T => U,
      child: LogicalPlan): AppendColumns = {
    val untypedFunc = func.asInstanceOf[Any => Any]
    val inputEncoder = implicitly[Encoder[T]]
    val argumentClass = inputEncoder.clsTag.runtimeClass
    val argumentSchema = inputEncoder.schema
    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
    val serializer = encoderFor[U].namedExpressions
    new AppendColumns(
      untypedFunc, argumentClass, argumentSchema, deserializer, serializer, child)
  }

  /**
   * Same as the two-argument variant, except that `inputAttributes` is additionally handed to
   * the [[UnresolvedDeserializer]].
   */
  def apply[T : Encoder, U : Encoder](
      func: T => U,
      inputAttributes: Seq[Attribute],
      child: LogicalPlan): AppendColumns = {
    val untypedFunc = func.asInstanceOf[Any => Any]
    val inputEncoder = implicitly[Encoder[T]]
    val argumentClass = inputEncoder.clsTag.runtimeClass
    val argumentSchema = inputEncoder.schema
    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes)
    val serializer = encoderFor[U].namedExpressions
    new AppendColumns(
      untypedFunc, argumentClass, argumentSchema, deserializer, serializer, child)
  }
}

/**
 * A relation produced by applying `func` to each element of the `child`, concatenating the
 * resulting columns at the end of the input row.
 *
 * @param deserializer used to extract the input to `func` from an input row.
 * @param serializer used to serialize the output of `func`.
 */
case class AppendColumns(
    func: Any => Any,
    argumentClass: Class[_],
    argumentSchema: StructType,
    deserializer: Expression,
    serializer: Seq[NamedExpression],
    child: LogicalPlan) extends UnaryNode {

  // The child's columns come first, followed by the newly appended ones.
  override def output: Seq[Attribute] = {
    child.output ++ newColumns
  }

  /** The attributes of the appended columns, one per serializer expression. */
  def newColumns: Seq[Attribute] = serializer.map(expr => expr.toAttribute)
}
270
/**
 * An optimized version of [[AppendColumns]] that can be executed on the deserialized object
 * directly.
 *
 * @param childSerializer named expressions producing the first part of the output.
 * @param newColumnsSerializer named expressions producing the appended part of the output.
 */
case class AppendColumnsWithObject(
    func: Any => Any,
    childSerializer: Seq[NamedExpression],
    newColumnsSerializer: Seq[NamedExpression],
    child: LogicalPlan) extends ObjectConsumer {

  // Attributes of the child serializer first, then those of the new-columns serializer.
  override def output: Seq[Attribute] = {
    val allSerializers = childSerializer ++ newColumnsSerializer
    allSerializers.map(_.toAttribute)
  }
}
282
/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
  /**
   * Builds a [[MapGroups]] node over `child` and wraps it in a serializer for `U`. The key
   * deserializer is bound to `groupingAttributes`, the value deserializer to `dataAttributes`.
   */
  def apply[K : Encoder, T : Encoder, U : Encoder](
      func: (K, Iterator[T]) => TraversableOnce[U],
      groupingAttributes: Seq[Attribute],
      dataAttributes: Seq[Attribute],
      child: LogicalPlan): LogicalPlan = {
    val untypedFunc = func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]]
    val keyDeserializer = UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes)
    val valueDeserializer = UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes)
    val objAttr = CatalystSerde.generateObjAttr[U]
    val grouped = new MapGroups(
      untypedFunc,
      keyDeserializer,
      valueDeserializer,
      groupingAttributes,
      dataAttributes,
      objAttr,
      child)
    CatalystSerde.serialize[U](grouped)
  }
}

/**
 * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
 * Func is invoked with an object representation of the grouping key and an iterator containing
 * the object representation of all the rows with that key.
 *
 * @param keyDeserializer used to extract the key object for each group.
 * @param valueDeserializer used to extract the items in the iterator from an input row.
 */
case class MapGroups(
    func: (Any, Iterator[Any]) => TraversableOnce[Any],
    keyDeserializer: Expression,
    valueDeserializer: Expression,
    groupingAttributes: Seq[Attribute],
    dataAttributes: Seq[Attribute],
    outputObjAttr: Attribute,
    child: LogicalPlan) extends UnaryNode with ObjectProducer
318
/** Factory for constructing new `FlatMapGroupsInR` nodes. */
object FlatMapGroupsInR {
  /**
   * Builds a `FlatMapGroupsInR` node over `child` and wraps it in a serializer based on a row
   * encoder for `schema`. The key deserializer is bound to `groupingAttributes` and the value
   * deserializer to `dataAttributes`, both via [[UnresolvedDeserializer]].
   *
   * @param schema the schema of the output rows.
   * @param inputSchema passed through unchanged to the operator.
   */
  def apply(
      func: Array[Byte],
      packageNames: Array[Byte],
      broadcastVars: Array[Broadcast[Object]],
      schema: StructType,
      keyDeserializer: Expression,
      valueDeserializer: Expression,
      inputSchema: StructType,
      groupingAttributes: Seq[Attribute],
      dataAttributes: Seq[Attribute],
      child: LogicalPlan): LogicalPlan = {
    // Build the output row encoder once and share it between the object attribute and the
    // serializer, instead of constructing an identical encoder twice.
    val outputEncoder = RowEncoder(schema)
    val mapped = FlatMapGroupsInR(
      func,
      packageNames,
      broadcastVars,
      inputSchema,
      schema,
      UnresolvedDeserializer(keyDeserializer, groupingAttributes),
      UnresolvedDeserializer(valueDeserializer, dataAttributes),
      groupingAttributes,
      dataAttributes,
      CatalystSerde.generateObjAttr(outputEncoder),
      child)
    CatalystSerde.serialize(mapped)(outputEncoder)
  }
}

/**
 * Logical operator carrying a serialized function `func` together with key and value
 * deserializers and the grouping and data attributes of `child`.
 *
 * NOTE(review): judging by the name and the parallel with [[MapPartitionsInR]], `func` is
 * presumably a serialized R function applied to each group -- confirm against the physical
 * operator.
 */
case class FlatMapGroupsInR(
    func: Array[Byte],
    packageNames: Array[Byte],
    broadcastVars: Array[Broadcast[Object]],
    inputSchema: StructType,
    outputSchema: StructType,
    keyDeserializer: Expression,
    valueDeserializer: Expression,
    groupingAttributes: Seq[Attribute],
    dataAttributes: Seq[Attribute],
    outputObjAttr: Attribute,
    child: LogicalPlan) extends UnaryNode with ObjectProducer {

  override lazy val schema = outputSchema

  // `func`, `packageNames` and `broadcastVars` are deliberately left out of `stringArgs`.
  override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
    keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr,
    child)
}
367
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
  /**
   * Builds a [[CoGroup]] node over `left` and `right` and wraps it in a serializer for `OUT`.
   *
   * @param leftGroup grouping attributes of the left side.
   * @param rightGroup grouping attributes of the right side; must have the same schema as
   *                   `leftGroup`.
   * @param leftAttr attributes the left value deserializer is bound to.
   * @param rightAttr attributes the right value deserializer is bound to.
   * @throws IllegalArgumentException if the schemas of `leftGroup` and `rightGroup` differ.
   */
  def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder](
      func: (K, Iterator[L], Iterator[R]) => TraversableOnce[OUT],
      leftGroup: Seq[Attribute],
      rightGroup: Seq[Attribute],
      leftAttr: Seq[Attribute],
      rightAttr: Seq[Attribute],
      left: LogicalPlan,
      right: LogicalPlan): LogicalPlan = {
    val leftKeySchema = StructType.fromAttributes(leftGroup)
    val rightKeySchema = StructType.fromAttributes(rightGroup)
    // The message is only built when the requirement fails.
    require(
      leftKeySchema == rightKeySchema,
      "CoGroup requires the left and right grouping attributes to have the same schema, " +
        s"but got $leftKeySchema and $rightKeySchema")

    val cogrouped = CoGroup(
      func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
      // The `leftGroup` and `rightGroup` are guaranteed to be of same schema, so it's safe to
      // resolve the `keyDeserializer` based on either of them, here we pick the left one.
      UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup),
      UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr),
      UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr),
      leftGroup,
      rightGroup,
      leftAttr,
      rightAttr,
      CatalystSerde.generateObjAttr[OUT],
      left,
      right)
    CatalystSerde.serialize[OUT](cogrouped)
  }
}

/**
 * A relation produced by applying `func` to each grouping key and associated values from left and
 * right children.
 */
case class CoGroup(
    func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
    keyDeserializer: Expression,
    leftDeserializer: Expression,
    rightDeserializer: Expression,
    leftGroup: Seq[Attribute],
    rightGroup: Seq[Attribute],
    leftAttr: Seq[Attribute],
    rightAttr: Seq[Attribute],
    outputObjAttr: Attribute,
    left: LogicalPlan,
    right: LogicalPlan) extends BinaryNode with ObjectProducer
414