/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
 * Replaces generic operations with specific variants that are designed to work with Spark
 * SQL Data Sources.
 */
case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {

  def resolver: Resolver = conf.resolver

  // Visible for testing.
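  // A rough usage sketch (attribute/expression names here are illustrative only): for
  //   CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)
  //   INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3
  // this is expected to be called with
  //   sourceAttributes      = output of `SELECT 1, 3`, e.g. [col1#0, col2#1]
  //   providedPartitions    = Map("b" -> Some("2"), "c" -> None)
  //   targetAttributes      = [a#2, b#3, c#4]
  //   targetPartitionSchema = StructType(b INT, c INT)
  // and to return roughly [col1#0, Alias(Cast(Literal("2"), IntegerType), "_staticPart")(), col2#1],
  // i.e. the static value for `b` is spliced in between the data columns and the dynamic
  // partition column `c`.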
  def convertStaticPartitions(
      sourceAttributes: Seq[Attribute],
      providedPartitions: Map[String, Option[String]],
      targetAttributes: Seq[Attribute],
      targetPartitionSchema: StructType): Seq[NamedExpression] = {

    assert(providedPartitions.exists(_._2.isDefined))

    val staticPartitions = providedPartitions.flatMap {
      case (partKey, Some(partValue)) => (partKey, partValue) :: Nil
      case (_, None) => Nil
    }

    // The sum of the number of static partition columns and columns provided in the SELECT
    // clause needs to match the number of columns of the target table.
    if (staticPartitions.size + sourceAttributes.size != targetAttributes.size) {
      throw new AnalysisException(
        s"The data to be inserted needs to have the same number of " +
          s"columns as the target table: target table has ${targetAttributes.size} " +
          s"column(s) but the inserted data has ${sourceAttributes.size + staticPartitions.size} " +
          s"column(s), which contain ${staticPartitions.size} partition column(s) having " +
          s"assigned constant values.")
    }

    if (providedPartitions.size != targetPartitionSchema.fields.size) {
      throw new AnalysisException(
        s"The data to be inserted needs to have the same number of " +
          s"partition columns as the target table: target table " +
          s"has ${targetPartitionSchema.fields.size} partition column(s) but the inserted " +
          s"data has ${providedPartitions.size} partition columns specified.")
    }

    staticPartitions.foreach {
      case (partKey, partValue) =>
        if (!targetPartitionSchema.fields.exists(field => resolver(field.name, partKey))) {
          throw new AnalysisException(
            s"$partKey is not a partition column. Partition columns are " +
              s"${targetPartitionSchema.fields.map(_.name).mkString("[", ",", "]")}")
        }
    }

    val partitionList = targetPartitionSchema.fields.map { field =>
      val potentialSpecs = staticPartitions.filter {
        case (partKey, partValue) => resolver(field.name, partKey)
      }
      if (potentialSpecs.isEmpty) {
        None
      } else if (potentialSpecs.size == 1) {
        val partValue = potentialSpecs.head._2
        Some(Alias(Cast(Literal(partValue), field.dataType), "_staticPart")())
      } else {
        throw new AnalysisException(
          s"Partition column ${field.name} has multiple values specified, " +
            s"${potentialSpecs.mkString("[", ", ", "]")}. Please only specify a single value.")
      }
    }

    // We first drop all leading static partitions using dropWhile and then check whether any
    // static partition appears after a dynamic partition.
    partitionList.dropWhile(_.isDefined).collectFirst {
      case Some(_) =>
        throw new AnalysisException(
          s"The ordering of partition columns is " +
            s"${targetPartitionSchema.fields.map(_.name).mkString("[", ",", "]")}. " +
            "All partition columns having constant values need to appear before other " +
            "partition columns that do not have an assigned constant value.")
    }

    assert(partitionList.take(staticPartitions.size).forall(_.isDefined))
    val projectList =
      sourceAttributes.take(targetAttributes.size - targetPartitionSchema.fields.size) ++
        partitionList.take(staticPartitions.size).map(_.get) ++
        sourceAttributes.takeRight(targetPartitionSchema.fields.size - staticPartitions.size)

    projectList
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // If the InsertIntoTable command is for a partitioned HadoopFsRelation and
    // the user has specified static partitions, we add a Project operator on top of the query
    // to include those constant column values in the query result.
    //
    // Example:
    // Let's say that we have a table "t", which is created by
    // CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)
    // The statement "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3"
    // will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3".
    //
    // Basically, we put those partition columns that have an assigned value back
    // into the SELECT clause. The output of the SELECT clause is organized as
    // normal_columns static_partitioning_columns dynamic_partitioning_columns.
    // static_partitioning_columns are partitioning columns that have assigned
    // values in the PARTITION clause (e.g. b in the above example).
    // dynamic_partitioning_columns are partitioning columns that are not assigned
    // values in the PARTITION clause (e.g. c in the above example).
    case insert @ logical.InsertIntoTable(
      relation @ LogicalRelation(t: HadoopFsRelation, _, _), parts, query, overwrite, false)
      if query.resolved && parts.exists(_._2.isDefined) =>

      val projectList = convertStaticPartitions(
        sourceAttributes = query.output,
        providedPartitions = parts,
        targetAttributes = relation.output,
        targetPartitionSchema = t.partitionSchema)

      // Remove the values assigned to the static partitions, since they have been moved into
      // the projectList.
      insert.copy(partition = parts.map(p => (p._1, None)), child = Project(projectList, query))


    case logical.InsertIntoTable(
      l @ LogicalRelation(t: HadoopFsRelation, _, table), _, query, overwrite, false)
        if query.resolved && t.schema.sameType(query.schema) =>

      // Sanity checks
      if (t.location.rootPaths.size != 1) {
        throw new AnalysisException(
          "Can only write data to relations with a single path.")
      }

      val outputPath = t.location.rootPaths.head
      val inputPaths = query.collect {
        case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths
      }.flatten

      val mode = if (overwrite.enabled) SaveMode.Overwrite else SaveMode.Append
      if (overwrite.enabled && inputPaths.contains(outputPath)) {
        throw new AnalysisException(
          "Cannot overwrite a path that is also being read from.")
      }

      val partitionSchema = query.resolve(
        t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
      val partitionsTrackedByCatalog =
        t.sparkSession.sessionState.conf.manageFilesourcePartitions &&
        l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
        l.catalogTable.get.tracksPartitionsInCatalog

      var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
      var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty

      val staticPartitionKeys: TablePartitionSpec = if (overwrite.enabled) {
        overwrite.staticPartitionKeys.map { case (k, v) =>
          (partitionSchema.map(_.name).find(_.equalsIgnoreCase(k)).get, v)
        }
      } else {
        Map.empty
      }

      // When partitions are tracked by the catalog, compute all custom partition locations that
      // may be relevant to the insertion job.
      if (partitionsTrackedByCatalog) {
        val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions(
          l.catalogTable.get.identifier, Some(staticPartitionKeys))
        initialMatchingPartitions = matchingPartitions.map(_.spec)
        customPartitionLocations = getCustomPartitionLocations(
          t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions)
      }

      // Callback for updating metastore partition metadata after the insertion job completes.
      // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand
      def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
        if (partitionsTrackedByCatalog) {
          val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions
          if (newPartitions.nonEmpty) {
            AlterTableAddPartitionCommand(
              l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
              ifNotExists = true).run(t.sparkSession)
          }
          if (overwrite.enabled) {
            val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
            if (deletedPartitions.nonEmpty) {
              AlterTableDropPartitionCommand(
                l.catalogTable.get.identifier, deletedPartitions.toSeq,
                ifExists = true, purge = false,
                retainData = true /* already deleted */).run(t.sparkSession)
            }
          }
        }
        t.location.refresh()
      }

      val insertCmd = InsertIntoHadoopFsRelationCommand(
        outputPath,
        staticPartitionKeys,
        customPartitionLocations,
        partitionSchema,
        t.bucketSpec,
        t.fileFormat,
        refreshPartitionsCallback,
        t.options,
        query,
        mode,
        table)

      insertCmd
  }

  /**
   * Given a set of input partitions, returns those that have locations that differ from the
   * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by
   * the user.
   *
   * @return a mapping from partition specs to their custom locations
   */
  private def getCustomPartitionLocations(
      spark: SparkSession,
      table: CatalogTable,
      basePath: Path,
      partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
    val hadoopConf = spark.sessionState.newHadoopConf
    val fs = basePath.getFileSystem(hadoopConf)
    val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    partitions.flatMap { p =>
      val defaultLocation = qualifiedBasePath.suffix(
        "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString
      val catalogLocation = new Path(p.location).makeQualified(
        fs.getUri, fs.getWorkingDirectory).toString
      if (catalogLocation != defaultLocation) {
        Some(p.spec -> catalogLocation)
      } else {
        None
      }
    }.toMap
  }
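
  // A small sketch of what `getCustomPartitionLocations` above returns (paths below are invented
  // for the example): with `basePath = /warehouse/t` and partition columns (k1, k2), a partition
  // with spec Map("k1" -> "a", "k2" -> "b") stored at the Hive default location
  // `/warehouse/t/k1=a/k2=b` is dropped from the result, whereas one whose catalog location is,
  // say, `/data/custom` is kept and mapped to that (fully qualified) location string.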
}


/**
 * Replaces a [[SimpleCatalogRelation]] with a data source table if its table properties contain
 * data source information.
 */
class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] {
  private def readDataSourceTable(
      sparkSession: SparkSession,
      simpleCatalogRelation: SimpleCatalogRelation): LogicalPlan = {
    val table = simpleCatalogRelation.catalogTable
    val pathOption = table.storage.locationUri.map("path" -> _)
    val dataSource =
      DataSource(
        sparkSession,
        userSpecifiedSchema = Some(table.schema),
        partitionColumns = table.partitionColumnNames,
        bucketSpec = table.bucketSpec,
        className = table.provider.get,
        options = table.storage.properties ++ pathOption)

    LogicalRelation(
      dataSource.resolveRelation(checkFilesExist = false),
      expectedOutputAttributes = Some(simpleCatalogRelation.output),
      catalogTable = Some(table))
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _)
        if DDLUtils.isDatasourceTable(s.metadata) =>
      i.copy(table = readDataSourceTable(sparkSession, s))

    case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) =>
      readDataSourceTable(sparkSession, s)
  }
}


/**
 * A Strategy for planning scans over data sources defined using the sources API.
 */
object DataSourceStrategy extends Strategy with Logging {
  def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
    case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) =>
      pruneFilterProjectRaw(
        l,
        projects,
        filters,
        (requestedColumns, allPredicates, _) =>
          toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil

    case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) =>
      pruneFilterProject(
        l,
        projects,
        filters,
        (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil

    case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _)) =>
      pruneFilterProject(
        l,
        projects,
        filters,
        (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil

    case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
      RowDataSourceScanExec(
        l.output,
        toCatalystRDD(l, baseRelation.buildScan()),
        baseRelation,
        UnknownPartitioning(0),
        Map.empty,
        None) :: Nil

    case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
      part, query, overwrite, false) if part.isEmpty =>
      ExecutedCommandExec(InsertIntoDataSourceCommand(l, query, overwrite)) :: Nil

    case _ => Nil
  }
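
  // A rough sketch of what `apply` above produces for a relation that mixes in
  // `PrunedFilteredScan` (operator and attribute names are illustrative): the logical plan
  //   Project(a#1, Filter(a#1 > 1, LogicalRelation(someRelation)))
  // is planned as a RowDataSourceScanExec whose RDD comes from
  //   someRelation.buildScan(Array("a"), Array(sources.GreaterThan("a", 1)))
  // wrapped in FilterExec and/or ProjectExec only if the relation reports the filter as
  // unhandled or extra columns need to be projected away (see `pruneFilterProject` below).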

  // Get the bucket ID based on the bucketing values.
  // Restriction: bucket pruning works only when there is exactly one bucketing column.
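  // For example (a sketch; the column and values are assumed): with a bucketing column `j: Int`
  // and `numBuckets = 8`, `getBucketId(j, 8, 5)` evaluates
  // `HashPartitioning(j :: Nil, 8).partitionIdExpression` against the single-field row [5] and
  // returns the bucket index in [0, 8) that rows with `j = 5` were written to, allowing the
  // reader to skip every other bucket file for an equality predicate on `j`.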
  def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
    val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
    mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
    val bucketIdGeneration = UnsafeProjection.create(
      HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
      bucketColumn :: Nil)

    bucketIdGeneration(mutableRow).getInt(0)
  }

  // Based on Public API.
  private def pruneFilterProject(
      relation: LogicalRelation,
      projects: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow]) = {
    pruneFilterProjectRaw(
      relation,
      projects,
      filterPredicates,
      (requestedColumns, _, pushedFilters) => {
        scanBuilder(requestedColumns, pushedFilters.toArray)
      })
  }

  // Based on Catalyst expressions. The `scanBuilder` function accepts three arguments:
  //
  //  1. A `Seq[Attribute]`, containing all required column attributes. Used to handle relation
  //     traits that support column pruning (e.g. `PrunedScan` and `PrunedFilteredScan`).
  //
  //  2. A `Seq[Expression]`, containing all gathered Catalyst filter expressions, only used for
  //     `CatalystScan`.
  //
  //  3. A `Seq[Filter]`, containing all data source `Filter`s that are converted from (possibly a
  //     subset of) Catalyst filter expressions and can be handled by `relation`.  Used to handle
  //     relation traits (`CatalystScan` excluded) that support filter push-down (e.g.
  //     `PrunedFilteredScan` and `HadoopFsRelation`).
  //
  // Note that 2 and 3 shouldn't be used together.
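  //
  // For example (a sketch, not tied to a particular data source), a `PrunedFilteredScan` relation
  // ends up being driven through this method by `pruneFilterProject` above with a builder
  // equivalent to
  //   (requestedColumns, _, pushedFilters) =>
  //     toCatalystRDD(l, requestedColumns,
  //       t.buildScan(requestedColumns.map(_.name).toArray, pushedFilters.toArray))
  // i.e. it consumes arguments 1 and 3 and ignores the Catalyst expressions in argument 2,
  // while a `CatalystScan` (see `apply` above) consumes arguments 1 and 2 and ignores argument 3.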
  private def pruneFilterProjectRaw(
    relation: LogicalRelation,
    projects: Seq[NamedExpression],
    filterPredicates: Seq[Expression],
    scanBuilder: (Seq[Attribute], Seq[Expression], Seq[Filter]) => RDD[InternalRow]): SparkPlan = {

    val projectSet = AttributeSet(projects.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))

    val candidatePredicates = filterPredicates.map { _ transform {
      case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes.
    }}

    val (unhandledPredicates, pushedFilters, handledFilters) =
      selectFilters(relation.relation, candidatePredicates)

    // A set of column attributes that are only referenced by pushed down filters.  We can
    // eliminate them from requested columns.
    val handledSet = {
      val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains)
      val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references))
      AttributeSet(handledPredicates.flatMap(_.references)) --
        (projectSet ++ unhandledSet).map(relation.attributeMap)
    }

    // Combines all Catalyst filter `Expression`s that are either not convertible to data source
    // `Filter`s or cannot be handled by `relation`.
    val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And)

    // These metadata values make scan plans uniquely identifiable for equality checking.
    // TODO(SPARK-17701) using strings for equality checking is brittle
    val metadata: Map[String, String] = {
      val pairs = ArrayBuffer.empty[(String, String)]

      // Mark filters which are handled by the underlying DataSource with an asterisk.
      if (pushedFilters.nonEmpty) {
        val markedFilters = for (filter <- pushedFilters) yield {
          if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
        }
        pairs += ("PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
      }
      pairs += ("ReadSchema" ->
        StructType.fromAttributes(projects.map(_.toAttribute)).catalogString)
      pairs.toMap
    }

    if (projects.map(_.toAttribute) == projects &&
        projectSet.size == projects.size &&
        filterSet.subsetOf(projectSet)) {
      // When it is possible to just use column pruning to get the right projection and
      // when the columns of this projection are enough to evaluate all filter conditions,
      // just do a scan followed by a filter, with no extra project.
      val requestedColumns = projects
        // Safe due to if above.
        .asInstanceOf[Seq[Attribute]]
        // Match original case of attributes.
        .map(relation.attributeMap)
        // Don't request columns that are only referenced by pushed filters.
        .filterNot(handledSet.contains)

      val scan = RowDataSourceScanExec(
        projects.map(_.toAttribute),
        scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
        relation.relation, UnknownPartitioning(0), metadata,
        relation.catalogTable.map(_.identifier))
      filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)
    } else {
      // Don't request columns that are only referenced by pushed filters.
      val requestedColumns =
        (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq

      val scan = RowDataSourceScanExec(
        requestedColumns,
        scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
        relation.relation, UnknownPartitioning(0), metadata,
        relation.catalogTable.map(_.identifier))
      execution.ProjectExec(
        projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan))
    }
  }

  /**
   * Converts an RDD of Row into an RDD of InternalRow with objects in Catalyst types.
   */
  private[this] def toCatalystRDD(
      relation: LogicalRelation,
      output: Seq[Attribute],
      rdd: RDD[Row]): RDD[InternalRow] = {
    if (relation.relation.needConversion) {
      execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
    } else {
      rdd.asInstanceOf[RDD[InternalRow]]
    }
  }

  /**
   * Converts an RDD of Row into an RDD of InternalRow with objects in Catalyst types.
   */
  private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = {
    toCatalystRDD(relation, relation.output, rdd)
  }

  /**
   * Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
   *
   * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
   */
  protected[sql] def translateFilter(predicate: Expression): Option[Filter] = {
    predicate match {
      case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
        Some(sources.EqualTo(a.name, convertToScala(v, t)))
      case expressions.EqualTo(Literal(v, t), a: Attribute) =>
        Some(sources.EqualTo(a.name, convertToScala(v, t)))

      case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
        Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
      case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
        Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))

      case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
        Some(sources.GreaterThan(a.name, convertToScala(v, t)))
      case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
        Some(sources.LessThan(a.name, convertToScala(v, t)))

      case expressions.LessThan(a: Attribute, Literal(v, t)) =>
        Some(sources.LessThan(a.name, convertToScala(v, t)))
      case expressions.LessThan(Literal(v, t), a: Attribute) =>
        Some(sources.GreaterThan(a.name, convertToScala(v, t)))

      case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
        Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
      case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
        Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))

      case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
        Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
      case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
        Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))

      case expressions.InSet(a: Attribute, set) =>
        val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
        Some(sources.In(a.name, set.toArray.map(toScala)))

      // The Optimizer only converts `In` to `InSet` when the value list exceeds a certain size,
      // so it is possible that we still get an `In` expression here that needs to be pushed down.
      case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
        val hSet = list.map(e => e.eval(EmptyRow))
        val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
        Some(sources.In(a.name, hSet.toArray.map(toScala)))

      case expressions.IsNull(a: Attribute) =>
        Some(sources.IsNull(a.name))
      case expressions.IsNotNull(a: Attribute) =>
        Some(sources.IsNotNull(a.name))

      case expressions.And(left, right) =>
        (translateFilter(left) ++ translateFilter(right)).reduceOption(sources.And)

      case expressions.Or(left, right) =>
        for {
          leftFilter <- translateFilter(left)
          rightFilter <- translateFilter(right)
        } yield sources.Or(leftFilter, rightFilter)

      case expressions.Not(child) =>
        translateFilter(child).map(sources.Not)

      case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
        Some(sources.StringStartsWith(a.name, v.toString))

      case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
        Some(sources.StringEndsWith(a.name, v.toString))

      case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
        Some(sources.StringContains(a.name, v.toString))

      case _ => None
    }
  }
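
  // Example translations for `translateFilter` above (a sketch; the column names are assumed):
  //   a === 1                  -> Some(sources.EqualTo("a", 1))
  //   Literal(1) < a           -> Some(sources.GreaterThan("a", 1))  (the comparison is flipped
  //                                                                   so the attribute stays on
  //                                                                   the left-hand side)
  //   a > 1 && b === "x"       -> Some(sources.And(GreaterThan("a", 1), EqualTo("b", "x")))
  //   a > 1 || length(b) === 2 -> None (a disjunction is dropped entirely when either side
  //                                     cannot be translated)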

  /**
   * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
   * and can be handled by `relation`.
   *
   * @return A triple of `Seq[Expression]`, `Seq[Filter]`, and `Set[Filter]`. The first element
   *         contains all Catalyst predicate [[Expression]]s that are either not convertible or
   *         cannot be handled by `relation`. The second element contains all converted data source
   *         [[Filter]]s that will be pushed down to the data source. The third element contains
   *         all [[Filter]]s that are completely handled by the data source and therefore do not
   *         need to be re-evaluated by Spark.
   */
  protected[sql] def selectFilters(
    relation: BaseRelation,
    predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {

    // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
    // called `predicate`s, while all data source filters of type `sources.Filter` are simply
    // called `filter`s.

    // A map from original Catalyst expressions to corresponding translated data source filters.
    // If a predicate is not in this map, it means it cannot be pushed down.
    val translatedMap: Map[Expression, Filter] = predicates.flatMap { p =>
      translateFilter(p).map(f => p -> f)
    }.toMap

    val pushedFilters: Seq[Filter] = translatedMap.values.toSeq

    // Catalyst predicate expressions that cannot be converted to data source filters.
    val nonconvertiblePredicates = predicates.filterNot(translatedMap.contains)

    // Data source filters that cannot be handled by `relation`. An unhandled filter means
    // the data source cannot guarantee the rows returned can pass the filter.
    // As a result we must return it so Spark can plan an extra filter operator.
    val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet
    val unhandledPredicates = translatedMap.filter { case (p, f) =>
      unhandledFilters.contains(f)
    }.keys
    val handledFilters = pushedFilters.toSet -- unhandledFilters

    (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
  }
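
  // A short sketch of `selectFilters` above (the predicates and the relation's behaviour are
  // assumed for illustration): given predicates [a > 1, a === b],
  //   - `a === b` compares two attributes and has no `sources.Filter` counterpart, so it lands in
  //     the first element and `pruneFilterProjectRaw` plans a FilterExec for it;
  //   - `a > 1` translates to `sources.GreaterThan("a", 1)` and is always offered to the source
  //     via the second element; it additionally appears in the third element only when
  //     `relation.unhandledFilters` does not report it back, meaning the source evaluates it
  //     completely and Spark need not re-check it.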
}
