/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import org.apache.spark.annotation.Since
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.MeanEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.partial.SumEvaluator
import org.apache.spark.util.StatCounter

/**
 * Extra functions available on RDDs of Doubles through an implicit conversion.
 */
class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
  /** Add up the elements in this RDD. */
  def sum(): Double = self.withScope {
    self.fold(0.0)(_ + _)
  }

  /**
   * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and
   * count of the RDD's elements in one operation.
   */
  def stats(): StatCounter = self.withScope {
    // One StatCounter per partition, merged pairwise, so the data is scanned only once.
    self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
  }

  /** Compute the mean of this RDD's elements. */
  def mean(): Double = self.withScope {
    stats().mean
  }

  /** Compute the population variance of this RDD's elements. */
  def variance(): Double = self.withScope {
    stats().variance
  }

  /** Compute the population standard deviation of this RDD's elements. */
  def stdev(): Double = self.withScope {
    stats().stdev
  }

  /**
   * Compute the sample standard deviation of this RDD's elements (which corrects for bias in
   * estimating the standard deviation by dividing by N-1 instead of N).
   */
  def sampleStdev(): Double = self.withScope {
    stats().sampleStdev
  }

  /**
   * Compute the sample variance of this RDD's elements (which corrects for bias in
   * estimating the variance by dividing by N-1 instead of N).
   */
  def sampleVariance(): Double = self.withScope {
    stats().sampleVariance
  }

  /**
   * Compute the population standard deviation of this RDD's elements.
   */
  @Since("2.1.0")
  def popStdev(): Double = self.withScope {
    stats().popStdev
  }

  /**
   * Compute the population variance of this RDD's elements.
   */
  @Since("2.1.0")
  def popVariance(): Double = self.withScope {
    stats().popVariance
  }

  /**
   * Approximate operation to return the mean within a timeout.
   *
   * @param timeout maximum time to wait for the job; passed through unchanged to
   *                `SparkContext.runApproximateJob`
   * @param confidence the desired statistical confidence, handed to the mean evaluator
   * @return the approximate mean as a [[org.apache.spark.partial.PartialResult]]
   */
  def meanApprox(
      timeout: Long,
      confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope {
    val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
    val evaluator = new MeanEvaluator(self.partitions.length, confidence)
    self.context.runApproximateJob(self, processPartition, evaluator, timeout)
  }

  /**
   * Approximate operation to return the sum within a timeout.
   *
   * @param timeout maximum time to wait for the job; passed through unchanged to
   *                `SparkContext.runApproximateJob`
   * @param confidence the desired statistical confidence, handed to the sum evaluator
   * @return the approximate sum as a [[org.apache.spark.partial.PartialResult]]
   */
  def sumApprox(
      timeout: Long,
      confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope {
    val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
    val evaluator = new SumEvaluator(self.partitions.length, confidence)
    self.context.runApproximateJob(self, processPartition, evaluator, timeout)
  }

  /**
   * Compute a histogram of the data using bucketCount number of buckets evenly
   * spaced between the minimum and maximum of the RDD. For example if the min
   * value is 0 and the max is 100 and there are two buckets the resulting
   * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1.
   * If the RDD is empty or contains infinity or NaN, an exception is thrown.
   * If the elements in the RDD do not vary (max == min), a single bucket is always returned.
   *
   * @param bucketCount number of evenly spaced buckets; must be at least 1
   * @return the bucket boundaries (bucketCount + 1 values, or two equal values when
   *         max == min) paired with the number of elements that fell into each bucket
   * @throws IllegalArgumentException if bucketCount is less than 1
   * @throws UnsupportedOperationException if the RDD is empty or contains +/-infinity or NaN
   */
  def histogram(bucketCount: Int): (Array[Double], Array[Long]) = self.withScope {
    // Validate before launching any job. Without this a non-positive count produced a
    // single-element bucket array that was only rejected after a full pass over the data,
    // and it was silently ignored when max == min.
    require(bucketCount > 0, s"Number of buckets must be at least 1, but got $bucketCount")
    // Scala's built-in range has issues. See #SI-8782
    def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = {
      val span = max - min
      Range.Int(0, steps, 1).map(s => min + (s * span) / steps) :+ max
    }
    // Compute the maximum and the minimum in a single pass. foldLeft streams through the
    // partition iterator, whereas Iterator.foldRight would first buffer the whole partition
    // in memory. max/min are order-independent (and NaN-propagating), so the result is the same.
    val (max: Double, min: Double) = self.mapPartitions { items =>
      Iterator(items.foldLeft((Double.NegativeInfinity, Double.PositiveInfinity)) {
        case ((curMax, curMin), e) => (curMax.max(e), curMin.min(e))
      })
    }.reduce { (maxmin1, maxmin2) =>
      (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
    }
    // Empty partitions leave the (-inf, +inf) seeds untouched and any NaN element propagates
    // through max/min, so this single check covers the empty, infinite and NaN cases.
    if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity) {
      throw new UnsupportedOperationException(
        "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
    }
    val range = if (min != max) {
      // Range.Double.inclusive(min, max, increment)
      // The above code doesn't always work. See Scala bug #SI-8782.
      // https://issues.scala-lang.org/browse/SI-8782
      customRange(min, max, bucketCount)
    } else {
      // All elements are equal: one degenerate, closed bucket [min, min].
      List(min, min)
    }
    val buckets = range.toArray
    // The boundaries were generated evenly spaced, so the O(1) bucket lookup is valid.
    (buckets, histogram(buckets, evenBuckets = true))
  }

  /**
   * Compute a histogram using the provided buckets. The buckets are all open
   * to the right except for the last which is closed.
   * e.g. for the array
   * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
   * e.g {@code 1<=x<10, 10<=x<20, 20<=x<=50}
   * And on the input of 1 and 50 we would have a histogram of 1, 0, 1
   *
   * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
   * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets
   * to true.
   * buckets must be sorted and not contain any duplicates.
   * buckets array must be at least two elements
   * All NaN entries are treated the same. If you have a NaN bucket it must be
   * the maximum value of the last position and all NaN entries will be counted
   * in that bucket. This applies only when evenBuckets is false; the even-bucket
   * path never counts NaN inputs.
   *
   * @param buckets sorted bucket boundaries; n boundaries define n - 1 buckets
   * @param evenBuckets whether the boundaries are evenly spaced, enabling O(1) bucket lookup
   * @return the number of elements in each of the buckets.length - 1 buckets; elements outside
   *         [buckets.head, buckets.last] are not counted
   * @throws IllegalArgumentException if buckets has fewer than two elements
   */
  def histogram(
      buckets: Array[Double],
      evenBuckets: Boolean = false): Array[Long] = self.withScope {
    if (buckets.length < 2) {
      throw new IllegalArgumentException("buckets array must have at least two elements")
    }
    // The histogramPartition function computes the partial histogram for a given
    // partition. The provided bucketFunction determines which bucket in the array
    // to increment or returns None if there is no bucket. This is done so we can
    // specialize for uniformly distributed buckets and save the O(log n) binary
    // search cost.
    def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
        Iterator[Array[Long]] = {
      val counters = new Array[Long](buckets.length - 1)
      while (iter.hasNext) {
        bucketFunction(iter.next()) match {
          case Some(x: Int) => counters(x) += 1
          case _ => // No-Op: element does not fall into any bucket
        }
      }
      Iterator(counters)
    }
    // Merge the counters. Mutates and returns a1, which is safe because each partition
    // produces its own freshly allocated counter array.
    def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
      a1.indices.foreach(i => a1(i) += a2(i))
      a1
    }
    // Basic bucket function. This works using Java's built in Array
    // binary search. Takes log(size(buckets))
    def basicBucketFunction(e: Double): Option[Int] = {
      val location = java.util.Arrays.binarySearch(buckets, e)
      if (location < 0) {
        // If the location is less than 0 then the insertion point in the array
        // to keep it sorted is -location-1
        val insertionPoint = -location - 1
        // If we have to insert before the first element or after the last one
        // it's out of bounds.
        // We do this rather than buckets.lengthCompare(insertionPoint)
        // because Array[Double] fails to override it (for now).
        if (insertionPoint > 0 && insertionPoint < buckets.length) {
          Some(insertionPoint - 1)
        } else {
          None
        }
      } else if (location < buckets.length - 1) {
        // Exact match, just insert here
        Some(location)
      } else {
        // Exact match to the last element: the last bucket is closed on the right.
        Some(location - 1)
      }
    }
    // Determine the bucket function in constant time. Requires that buckets are evenly spaced
    def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = {
      // NaN inputs and inputs outside [min, max] are not counted in any bucket.
      if (e.isNaN || e < min || e > max) {
        None
      } else {
        // Compute ratio of e's distance along range to total range first, for better precision.
        // When min == max the ratio is 0/0 = NaN, and NaN.toInt is 0, so e lands in bucket 0.
        val bucketNumber = (((e - min) / (max - min)) * count).toInt
        // should be less than count, but will equal count if e == max, in which case
        // it's part of the last end-range-inclusive bucket, so return count-1
        Some(math.min(bucketNumber, count - 1))
      }
    }
    // Decide which bucket function to pass to histogramPartition. We decide here
    // rather than having a general function so that the decision need only be made
    // once rather than once per shard
    val bucketFunction = if (evenBuckets) {
      fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _
    } else {
      basicBucketFunction _
    }
    if (self.partitions.length == 0) {
      new Array[Long](buckets.length - 1)
    } else {
      // reduce() requires a non-empty RDD. This works because the mapPartitions will make
      // non-empty partitions out of empty ones. It does not handle the no-partitions case,
      // which is handled by the branch above.
      self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
    }
  }

}