/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag

import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter}
import org.apache.spark.internal.Logging
import org.apache.spark.util.ThreadUtils

/**
 * A set of asynchronous RDD actions available through an implicit conversion.
 */
class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {

  /**
   * Returns a future for counting the number of elements in the RDD.
   */
  def countAsync(): FutureAction[Long] = self.withScope {
    val totalCount = new AtomicLong
    self.context.submitJob(
      self,
      (iter: Iterator[T]) => {
        var result = 0L
        while (iter.hasNext) {
          result += 1L
          iter.next()
        }
        result
      },
      Range(0, self.partitions.length),
      (index: Int, data: Long) => totalCount.addAndGet(data),
      totalCount.get())
  }

  /**
   * Returns a future for retrieving all elements of this RDD.
   */
  def collectAsync(): FutureAction[Seq[T]] = self.withScope {
    val results = new Array[Array[T]](self.partitions.length)
    self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length),
      (index, data) => results(index) = data, results.flatten.toSeq)
  }
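
  /*
   * Usage sketch for the actions above (a hypothetical example, not part of
   * this file): `sc` stands in for an existing SparkContext, and the implicit
   * conversion from RDD[T] to AsyncRDDActions[T] is provided by the RDD
   * companion object. Since a FutureAction is a scala.concurrent.Future, the
   * usual callback combinators apply:
   *
   *   import scala.concurrent.ExecutionContext.Implicits.global
   *   import scala.util.{Failure, Success}
   *
   *   val rdd = sc.parallelize(1 to 1000, numSlices = 10)
   *   val countF: FutureAction[Long] = rdd.countAsync()
   *   countF.onComplete {
   *     case Success(n) => println(s"count = $n")
   *     case Failure(e) => e.printStackTrace()
   *   }
   *
   *   // collectAsync gathers every partition; only sensible for small RDDs.
   *   val allF: FutureAction[Seq[Int]] = rdd.collectAsync()
   */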

  /**
   * Returns a future for retrieving the first num elements of the RDD.
   */
  def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
    val callSite = self.context.getCallSite
    val localProperties = self.context.getLocalProperties
    // Cached thread pool to handle aggregation of subtasks.
    implicit val executionContext = AsyncRDDActions.futureExecutionContext
    val results = new ArrayBuffer[T]
    val totalParts = self.partitions.length

    /*
     * Recursively triggers jobs to scan partitions until either the requested
     * number of elements is retrieved, or the partitions to scan are exhausted.
     * This implementation is non-blocking, asynchronously handling the
     * results of each job and triggering the next job using callbacks on futures.
     */
    def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] =
      if (results.size >= num || partsScanned >= totalParts) {
        Future.successful(results.toSeq)
      } else {
        // The number of partitions to try in this iteration. It is ok for this number to be
        // greater than totalParts because we actually cap it at totalParts in runJob.
        var numPartsToTry = 1L
        if (partsScanned > 0) {
          // If we didn't find any rows after the previous iteration, quadruple and retry.
          // Otherwise, interpolate the number of partitions we need to try, but overestimate it
          // by 50%. We also cap the estimation in the end.
          if (results.size == 0) {
            numPartsToTry = partsScanned * 4
          } else {
            // the left side of max is >= 1 whenever partsScanned >= 2
            numPartsToTry = Math.max(1,
              (1.5 * num * partsScanned / results.size).toInt - partsScanned)
            numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
          }
        }

        val left = num - results.size
        val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)

        val buf = new Array[Array[T]](p.size)
        self.context.setCallSite(callSite)
        self.context.setLocalProperties(localProperties)
        val job = jobSubmitter.submitJob(self,
          (it: Iterator[T]) => it.take(left).toArray,
          p,
          (index: Int, data: Array[T]) => buf(index) = data,
          ())
        job.flatMap { _ =>
          buf.foreach(results ++= _.take(num - results.size))
          continue(partsScanned + p.size)
        }
      }

    new ComplexFutureAction[Seq[T]](continue(0)(_))
  }

  /**
   * Applies a function f to all elements of this RDD.
   */
  def foreachAsync(f: T => Unit): FutureAction[Unit] = self.withScope {
    val cleanF = self.context.clean(f)
    self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF),
      Range(0, self.partitions.length), (index, data) => (), ())
  }

  /**
   * Applies a function f to each partition of this RDD.
   */
  def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = self.withScope {
    self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length),
      (index, data) => (), ())
  }
}

private object AsyncRDDActions {
  val futureExecutionContext = ExecutionContext.fromExecutorService(
    ThreadUtils.newDaemonCachedThreadPool("AsyncRDDActions-future", 128))
}
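
/*
 * End-to-end sketch of takeAsync and job control (hypothetical; `sc` again
 * stands in for an existing SparkContext). Unlike a plain Future, a
 * FutureAction also supports cancellation of the underlying Spark job(s):
 *
 *   val firstTen: FutureAction[Seq[Int]] =
 *     sc.parallelize(1 to 1000000, numSlices = 100).takeAsync(10)
 *
 *   // If the result is no longer needed, ask the scheduler to cancel the
 *   // job(s) submitted on behalf of this action.
 *   firstTen.cancel()
 */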