/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans.logical

import scala.language.existentials

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._

/**
 * Helpers that wrap a logical plan with the operators needed to move between rows and domain
 * objects described by an [[Encoder]].
 */
object CatalystSerde {
  /**
   * Wraps `child` in a [[DeserializeToObject]] that turns each input row into an object of type
   * `T`, using `T`'s encoder deserializer wrapped in an [[UnresolvedDeserializer]].
   */
  def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
    DeserializeToObject(deserializer, generateObjAttr[T], child)
  }

  /**
   * Wraps `child`, whose output is a single object of type `T`, in a [[SerializeFromObject]]
   * that turns each object back into a row using the named serializer expressions of `T`'s
   * encoder.
   */
  def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
    SerializeFromObject(encoderFor[T].namedExpressions, child)
  }

  /**
   * Creates an attribute named "obj" that holds an object of type `T`. Its data type is the data
   * type of `T`'s deserializer. It is nullable unless `T`'s runtime class is a JVM primitive.
   */
  def generateObjAttr[T : Encoder]: Attribute = {
    val enc = encoderFor[T]
    val dataType = enc.deserializer.dataType
    val nullable = !enc.clsTag.runtimeClass.isPrimitive
    AttributeReference("obj", dataType, nullable)()
  }
}

/**
 * A trait for logical operators that produce domain objects as output.
 * The output of this operator is a single-field safe row containing the produced object.
 */
trait ObjectProducer extends LogicalPlan {
  // The attribute that references the single object field this operator outputs.
  def outputObjAttr: Attribute

  override def output: Seq[Attribute] = outputObjAttr :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
}

/**
 * A trait for logical operators that consume domain objects as input.
 * The output of its child must be a single-field row containing the input object.
 */
trait ObjectConsumer extends UnaryNode {
  // Enforce the single-field contract eagerly, reporting the offending child output on failure.
  assert(child.output.length == 1,
    "The child of an ObjectConsumer must output exactly one attribute, but it outputs: " +
      child.output.mkString("[", ", ", "]"))

  // This operator always needs all columns of its child, even those it doesn't reference.
  override def references: AttributeSet = child.outputSet

  // The single attribute of the child's output, i.e. the input object.
  def inputObjAttr: Attribute = child.output.head
}

/**
 * Takes the input row from child and turns it into object using the given deserializer expression.
 */
case class DeserializeToObject(
    deserializer: Expression,
    outputObjAttr: Attribute,
    child: LogicalPlan) extends UnaryNode with ObjectProducer

/**
 * Takes the input object from child and turns it into unsafe row using the given serializer
 * expression.
86 */ 87case class SerializeFromObject( 88 serializer: Seq[NamedExpression], 89 child: LogicalPlan) extends ObjectConsumer { 90 91 override def output: Seq[Attribute] = serializer.map(_.toAttribute) 92} 93 94object MapPartitions { 95 def apply[T : Encoder, U : Encoder]( 96 func: Iterator[T] => Iterator[U], 97 child: LogicalPlan): LogicalPlan = { 98 val deserialized = CatalystSerde.deserialize[T](child) 99 val mapped = MapPartitions( 100 func.asInstanceOf[Iterator[Any] => Iterator[Any]], 101 CatalystSerde.generateObjAttr[U], 102 deserialized) 103 CatalystSerde.serialize[U](mapped) 104 } 105} 106 107/** 108 * A relation produced by applying `func` to each partition of the `child`. 109 */ 110case class MapPartitions( 111 func: Iterator[Any] => Iterator[Any], 112 outputObjAttr: Attribute, 113 child: LogicalPlan) extends ObjectConsumer with ObjectProducer 114 115object MapPartitionsInR { 116 def apply( 117 func: Array[Byte], 118 packageNames: Array[Byte], 119 broadcastVars: Array[Broadcast[Object]], 120 schema: StructType, 121 encoder: ExpressionEncoder[Row], 122 child: LogicalPlan): LogicalPlan = { 123 val deserialized = CatalystSerde.deserialize(child)(encoder) 124 val mapped = MapPartitionsInR( 125 func, 126 packageNames, 127 broadcastVars, 128 encoder.schema, 129 schema, 130 CatalystSerde.generateObjAttr(RowEncoder(schema)), 131 deserialized) 132 CatalystSerde.serialize(mapped)(RowEncoder(schema)) 133 } 134} 135 136/** 137 * A relation produced by applying a serialized R function `func` to each partition of the `child`. 
138 * 139 */ 140case class MapPartitionsInR( 141 func: Array[Byte], 142 packageNames: Array[Byte], 143 broadcastVars: Array[Broadcast[Object]], 144 inputSchema: StructType, 145 outputSchema: StructType, 146 outputObjAttr: Attribute, 147 child: LogicalPlan) extends ObjectConsumer with ObjectProducer { 148 override lazy val schema = outputSchema 149 150 override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema, 151 outputObjAttr, child) 152} 153 154object MapElements { 155 def apply[T : Encoder, U : Encoder]( 156 func: AnyRef, 157 child: LogicalPlan): LogicalPlan = { 158 val deserialized = CatalystSerde.deserialize[T](child) 159 val mapped = MapElements( 160 func, 161 implicitly[Encoder[T]].clsTag.runtimeClass, 162 implicitly[Encoder[T]].schema, 163 CatalystSerde.generateObjAttr[U], 164 deserialized) 165 CatalystSerde.serialize[U](mapped) 166 } 167} 168 169/** 170 * A relation produced by applying `func` to each element of the `child`. 171 */ 172case class MapElements( 173 func: AnyRef, 174 argumentClass: Class[_], 175 argumentSchema: StructType, 176 outputObjAttr: Attribute, 177 child: LogicalPlan) extends ObjectConsumer with ObjectProducer 178 179object TypedFilter { 180 def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = { 181 TypedFilter( 182 func, 183 implicitly[Encoder[T]].clsTag.runtimeClass, 184 implicitly[Encoder[T]].schema, 185 UnresolvedDeserializer(encoderFor[T].deserializer), 186 child) 187 } 188} 189 190/** 191 * A relation produced by applying `func` to each element of the `child` and filter them by the 192 * resulting boolean value. 193 * 194 * This is logically equal to a normal [[Filter]] operator whose condition expression is decoding 195 * the input row to object and apply the given function with decoded object. However we need the 196 * encapsulation of [[TypedFilter]] to make the concept more clear and make it easier to write 197 * optimizer rules. 
198 */ 199case class TypedFilter( 200 func: AnyRef, 201 argumentClass: Class[_], 202 argumentSchema: StructType, 203 deserializer: Expression, 204 child: LogicalPlan) extends UnaryNode { 205 206 override def output: Seq[Attribute] = child.output 207 208 def withObjectProducerChild(obj: LogicalPlan): Filter = { 209 assert(obj.output.length == 1) 210 Filter(typedCondition(obj.output.head), obj) 211 } 212 213 def typedCondition(input: Expression): Expression = { 214 val (funcClass, methodName) = func match { 215 case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call" 216 case _ => classOf[Any => Boolean] -> "apply" 217 } 218 val funcObj = Literal.create(func, ObjectType(funcClass)) 219 Invoke(funcObj, methodName, BooleanType, input :: Nil) 220 } 221} 222 223/** Factory for constructing new `AppendColumn` nodes. */ 224object AppendColumns { 225 def apply[T : Encoder, U : Encoder]( 226 func: T => U, 227 child: LogicalPlan): AppendColumns = { 228 new AppendColumns( 229 func.asInstanceOf[Any => Any], 230 implicitly[Encoder[T]].clsTag.runtimeClass, 231 implicitly[Encoder[T]].schema, 232 UnresolvedDeserializer(encoderFor[T].deserializer), 233 encoderFor[U].namedExpressions, 234 child) 235 } 236 237 def apply[T : Encoder, U : Encoder]( 238 func: T => U, 239 inputAttributes: Seq[Attribute], 240 child: LogicalPlan): AppendColumns = { 241 new AppendColumns( 242 func.asInstanceOf[Any => Any], 243 implicitly[Encoder[T]].clsTag.runtimeClass, 244 implicitly[Encoder[T]].schema, 245 UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes), 246 encoderFor[U].namedExpressions, 247 child) 248 } 249} 250 251/** 252 * A relation produced by applying `func` to each element of the `child`, concatenating the 253 * resulting columns at the end of the input row. 254 * 255 * @param deserializer used to extract the input to `func` from an input row. 256 * @param serializer use to serialize the output of `func`. 
257 */ 258case class AppendColumns( 259 func: Any => Any, 260 argumentClass: Class[_], 261 argumentSchema: StructType, 262 deserializer: Expression, 263 serializer: Seq[NamedExpression], 264 child: LogicalPlan) extends UnaryNode { 265 266 override def output: Seq[Attribute] = child.output ++ newColumns 267 268 def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) 269} 270 271/** 272 * An optimized version of [[AppendColumns]], that can be executed on deserialized object directly. 273 */ 274case class AppendColumnsWithObject( 275 func: Any => Any, 276 childSerializer: Seq[NamedExpression], 277 newColumnsSerializer: Seq[NamedExpression], 278 child: LogicalPlan) extends ObjectConsumer { 279 280 override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute) 281} 282 283/** Factory for constructing new `MapGroups` nodes. */ 284object MapGroups { 285 def apply[K : Encoder, T : Encoder, U : Encoder]( 286 func: (K, Iterator[T]) => TraversableOnce[U], 287 groupingAttributes: Seq[Attribute], 288 dataAttributes: Seq[Attribute], 289 child: LogicalPlan): LogicalPlan = { 290 val mapped = new MapGroups( 291 func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], 292 UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), 293 UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), 294 groupingAttributes, 295 dataAttributes, 296 CatalystSerde.generateObjAttr[U], 297 child) 298 CatalystSerde.serialize[U](mapped) 299 } 300} 301 302/** 303 * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. 304 * Func is invoked with an object representation of the grouping key an iterator containing the 305 * object representation of all the rows with that key. 306 * 307 * @param keyDeserializer used to extract the key object for each group. 308 * @param valueDeserializer used to extract the items in the iterator from an input row. 
309 */ 310case class MapGroups( 311 func: (Any, Iterator[Any]) => TraversableOnce[Any], 312 keyDeserializer: Expression, 313 valueDeserializer: Expression, 314 groupingAttributes: Seq[Attribute], 315 dataAttributes: Seq[Attribute], 316 outputObjAttr: Attribute, 317 child: LogicalPlan) extends UnaryNode with ObjectProducer 318 319/** Factory for constructing new `FlatMapGroupsInR` nodes. */ 320object FlatMapGroupsInR { 321 def apply( 322 func: Array[Byte], 323 packageNames: Array[Byte], 324 broadcastVars: Array[Broadcast[Object]], 325 schema: StructType, 326 keyDeserializer: Expression, 327 valueDeserializer: Expression, 328 inputSchema: StructType, 329 groupingAttributes: Seq[Attribute], 330 dataAttributes: Seq[Attribute], 331 child: LogicalPlan): LogicalPlan = { 332 val mapped = FlatMapGroupsInR( 333 func, 334 packageNames, 335 broadcastVars, 336 inputSchema, 337 schema, 338 UnresolvedDeserializer(keyDeserializer, groupingAttributes), 339 UnresolvedDeserializer(valueDeserializer, dataAttributes), 340 groupingAttributes, 341 dataAttributes, 342 CatalystSerde.generateObjAttr(RowEncoder(schema)), 343 child) 344 CatalystSerde.serialize(mapped)(RowEncoder(schema)) 345 } 346} 347 348case class FlatMapGroupsInR( 349 func: Array[Byte], 350 packageNames: Array[Byte], 351 broadcastVars: Array[Broadcast[Object]], 352 inputSchema: StructType, 353 outputSchema: StructType, 354 keyDeserializer: Expression, 355 valueDeserializer: Expression, 356 groupingAttributes: Seq[Attribute], 357 dataAttributes: Seq[Attribute], 358 outputObjAttr: Attribute, 359 child: LogicalPlan) extends UnaryNode with ObjectProducer{ 360 361 override lazy val schema = outputSchema 362 363 override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema, 364 keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, 365 child) 366} 367 368/** Factory for constructing new `CoGroup` nodes. 
*/ 369object CoGroup { 370 def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder]( 371 func: (K, Iterator[L], Iterator[R]) => TraversableOnce[OUT], 372 leftGroup: Seq[Attribute], 373 rightGroup: Seq[Attribute], 374 leftAttr: Seq[Attribute], 375 rightAttr: Seq[Attribute], 376 left: LogicalPlan, 377 right: LogicalPlan): LogicalPlan = { 378 require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) 379 380 val cogrouped = CoGroup( 381 func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], 382 // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to 383 // resolve the `keyDeserializer` based on either of them, here we pick the left one. 384 UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup), 385 UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr), 386 UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr), 387 leftGroup, 388 rightGroup, 389 leftAttr, 390 rightAttr, 391 CatalystSerde.generateObjAttr[OUT], 392 left, 393 right) 394 CatalystSerde.serialize[OUT](cogrouped) 395 } 396} 397 398/** 399 * A relation produced by applying `func` to each grouping key and associated values from left and 400 * right children. 401 */ 402case class CoGroup( 403 func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], 404 keyDeserializer: Expression, 405 leftDeserializer: Expression, 406 rightDeserializer: Expression, 407 leftGroup: Seq[Attribute], 408 rightGroup: Seq[Attribute], 409 leftAttr: Seq[Attribute], 410 rightAttr: Seq[Attribute], 411 outputObjAttr: Attribute, 412 left: LogicalPlan, 413 right: LogicalPlan) extends BinaryNode with ObjectProducer 414