1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.rdd
19
20import java.util.concurrent.atomic.AtomicLong
21
22import scala.collection.mutable.ArrayBuffer
23import scala.concurrent.{ExecutionContext, Future}
24import scala.reflect.ClassTag
25
26import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter}
27import org.apache.spark.internal.Logging
28import org.apache.spark.util.ThreadUtils
29
30/**
31 * A set of asynchronous RDD actions available through an implicit conversion.
32 */
class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {

  /**
   * Returns a future for counting the number of elements in the RDD.
   *
   * Every partition is counted by its own task; the per-partition counts are added into a
   * shared `AtomicLong` by the result handler, and the accumulated total is the job result.
   */
  def countAsync(): FutureAction[Long] = self.withScope {
    val totalCount = new AtomicLong
    self.context.submitJob(
      self,
      (iter: Iterator[T]) => {
        // Drain the iterator, counting elements without materializing them.
        var result = 0L
        while (iter.hasNext) {
          result += 1L
          iter.next()
        }
        result
      },
      Range(0, self.partitions.length),
      (index: Int, data: Long) => totalCount.addAndGet(data),
      totalCount.get())
  }

  /**
   * Returns a future for retrieving all elements of this RDD.
   *
   * Each partition is converted to an array and stored in the slot matching its partition
   * index, so the flattened result keeps partition order.
   */
  def collectAsync(): FutureAction[Seq[T]] = self.withScope {
    val results = new Array[Array[T]](self.partitions.length)
    // NOTE(review): `results.flatten.toSeq` is only meaningful if submitJob evaluates its last
    // argument after the handlers have run (i.e. it is a by-name parameter) -- confirm.
    self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length),
      (index, data) => results(index) = data, results.flatten.toSeq)
  }

  /**
   * Returns a future for retrieving the first num elements of the RDD.
   *
   * Partitions are scanned in increasing index order, in batches whose size grows from one
   * job to the next, until `num` elements have been gathered or every partition was scanned.
   *
   * @param num the maximum number of elements to return
   */
  def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
    // Captured here so they can be re-applied before each follow-up job is submitted below.
    val callSite = self.context.getCallSite
    val localProperties = self.context.getLocalProperties
    // Cached thread pool to handle aggregation of subtasks.
    implicit val executionContext: ExecutionContext = AsyncRDDActions.futureExecutionContext
    val results = new ArrayBuffer[T]
    val totalParts = self.partitions.length

    /*
      Recursively triggers jobs to scan partitions until either the requested
      number of elements are retrieved, or the partitions to scan are exhausted.
      This implementation is non-blocking, asynchronously handling the
      results of each job and triggering the next job using callbacks on futures.
     */
    def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] =
      if (results.size >= num || partsScanned >= totalParts) {
        Future.successful(results.toSeq)
      } else {
        // The number of partitions to try in this iteration. It is ok for this number to be
        // greater than totalParts because it is capped at totalParts when the range is built.
        // All arithmetic on `partsScanned * 4` is done in Long (`4L`) so that it cannot
        // overflow Int before being widened.
        val numPartsToTry: Long =
          if (partsScanned > 0) {
            // If we didn't find any rows after the previous iteration, quadruple and retry.
            // Otherwise, interpolate the number of partitions we need to try, but overestimate
            // it by 50%. We also cap the estimation in the end.
            if (results.isEmpty) {
              partsScanned * 4L
            } else {
              // the left side of max is >=1 whenever partsScanned >= 2
              val estimate = Math.max(1,
                (1.5 * num * partsScanned / results.size).toInt - partsScanned)
              Math.min(estimate, partsScanned * 4L)
            }
          } else {
            1L
          }

        // Computed outside the task closure so the closure only references this Int.
        val left = num - results.size
        val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)

        // One slot per partition of this batch, filled by the result handler.
        val buf = new Array[Array[T]](p.size)
        self.context.setCallSite(callSite)
        self.context.setLocalProperties(localProperties)
        val job = jobSubmitter.submitJob(self,
          (it: Iterator[T]) => it.take(left).toArray,
          p,
          (index: Int, data: Array[T]) => buf(index) = data,
          ())
        job.flatMap { _ =>
          // Append in batch order, never exceeding `num` elements overall.
          buf.foreach(results ++= _.take(num - results.size))
          continue(partsScanned + p.size)
        }
      }

    new ComplexFutureAction[Seq[T]](continue(0)(_))
  }

  /**
   * Applies a function f to all elements of this RDD.
   *
   * @param f the function to run on every element; it is passed through
   *          `self.context.clean` before the job is submitted
   */
  def foreachAsync(f: T => Unit): FutureAction[Unit] = self.withScope {
    val cleanF = self.context.clean(f)
    self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length),
      (index, data) => (), ())
  }

  /**
   * Applies a function f to each partition of this RDD.
   *
   * @param f the function to run once per partition, receiving that partition's iterator
   */
  def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = self.withScope {
    self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length),
      (index, data) => (), ())
  }
}
139
private object AsyncRDDActions {
  /**
   * Execution context that `takeAsync` brings into implicit scope for its future callbacks.
   * It wraps the executor returned by `ThreadUtils.newDaemonCachedThreadPool` for the
   * "AsyncRDDActions-future" prefix; the second argument (128) is presumably the maximum
   * number of threads -- confirm against ThreadUtils.
   */
  val futureExecutionContext = {
    val daemonPool = ThreadUtils.newDaemonCachedThreadPool("AsyncRDDActions-future", 128)
    ExecutionContext.fromExecutorService(daemonPool)
  }
}
144