/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import org.apache.spark.annotation.Since
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.MeanEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.partial.SumEvaluator
import org.apache.spark.util.StatCounter

/**
 * Extra functions available on RDDs of Doubles through an implicit conversion.
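 *
 * A minimal usage sketch (illustrative only): it assumes a `SparkContext` named `sc` is in
 * scope and that the implicit conversion from `RDD[Double]` to this class has been applied.
 * {{{
 *   val rdd = sc.parallelize(Seq(1.0, 2.0, 3.0))
 *   rdd.sum()    // 6.0
 *   rdd.mean()   // 2.0
 *   rdd.stdev()  // population standard deviation
 * }}}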
 */
class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
  /** Add up the elements in this RDD. */
  def sum(): Double = self.withScope {
    self.fold(0.0)(_ + _)
  }

  /**
   * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and
   * count of the RDD's elements in one operation.
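   *
   * A hedged usage sketch (assumes an `RDD[Double]` named `rdd`):
   * {{{
   *   val s = rdd.stats()
   *   (s.count, s.mean, s.variance)  // all computed in a single pass over the data
   * }}}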
   */
  def stats(): StatCounter = self.withScope {
    self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
  }

  /** Compute the mean of this RDD's elements. */
  def mean(): Double = self.withScope {
    stats().mean
  }

  /** Compute the population variance of this RDD's elements. */
  def variance(): Double = self.withScope {
    stats().variance
  }

  /** Compute the population standard deviation of this RDD's elements. */
  def stdev(): Double = self.withScope {
    stats().stdev
  }

  /**
   * Compute the sample standard deviation of this RDD's elements (which corrects for bias in
   * estimating the standard deviation by dividing by N-1 instead of N).
   */
  def sampleStdev(): Double = self.withScope {
    stats().sampleStdev
  }

  /**
   * Compute the sample variance of this RDD's elements (which corrects for bias in
   * estimating the variance by dividing by N-1 instead of N).
   */
  def sampleVariance(): Double = self.withScope {
    stats().sampleVariance
  }

  /**
   * Compute the population standard deviation of this RDD's elements.
   */
  @Since("2.1.0")
  def popStdev(): Double = self.withScope {
    stats().popStdev
  }

  /**
   * Compute the population variance of this RDD's elements.
   */
  @Since("2.1.0")
  def popVariance(): Double = self.withScope {
    stats().popVariance
  }

  /**
   * Approximate operation to return the mean within a timeout.
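   *
   * A hedged usage sketch (assumes an `RDD[Double]` named `rdd`; the timeout is given in
   * milliseconds):
   * {{{
   *   val approx: PartialResult[BoundedDouble] = rdd.meanApprox(1000L, 0.95)
   *   approx.initialValue  // a BoundedDouble with mean, confidence, low and high bounds
   * }}}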
   */
  def meanApprox(
      timeout: Long,
      confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope {
    val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
    val evaluator = new MeanEvaluator(self.partitions.length, confidence)
    self.context.runApproximateJob(self, processPartition, evaluator, timeout)
  }

  /**
   * Approximate operation to return the sum within a timeout.
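   *
   * A hedged usage sketch, analogous to `meanApprox` (assumes an `RDD[Double]` named `rdd`):
   * {{{
   *   rdd.sumApprox(1000L).getFinalValue()  // blocks, then returns a BoundedDouble with bounds
   * }}}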
   */
  def sumApprox(
      timeout: Long,
      confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope {
    val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
    val evaluator = new SumEvaluator(self.partitions.length, confidence)
    self.context.runApproximateJob(self, processPartition, evaluator, timeout)
  }

  /**
   * Compute a histogram of the data using `bucketCount` buckets evenly spaced between the
   * minimum and maximum of the RDD. For example, if the min value is 0, the max is 100 and
   * there are two buckets, the resulting buckets will be [0, 50) and [50, 100]. `bucketCount`
   * must be at least 1.
   * If the RDD contains infinity or NaN, this method throws an exception.
   * If the elements of the RDD do not vary (max == min), a single bucket is returned.
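   *
   * A hedged usage sketch (assumes an `RDD[Double]` named `rdd` holding 1.0, 2.0, 3.0, 4.0):
   * {{{
   *   rdd.histogram(2)  // (Array(1.0, 2.5, 4.0), Array(2, 2))
   * }}}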
   */
  def histogram(bucketCount: Int): (Array[Double], Array[Long]) = self.withScope {
    // Enforce the documented contract that bucketCount is at least 1.
    require(bucketCount > 0, "bucketCount must be greater than 0")
    // Scala's built-in range has issues. See #SI-8782
    def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = {
      val span = max - min
      Range.Int(0, steps, 1).map(s => min + (s * span) / steps) :+ max
    }
    // Compute the minimum and the maximum
    val (max: Double, min: Double) = self.mapPartitions { items =>
      Iterator(items.foldRight((Double.NegativeInfinity, Double.PositiveInfinity))(
        (e: Double, x: (Double, Double)) => (x._1.max(e), x._2.min(e))))
    }.reduce { (maxmin1, maxmin2) =>
      (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
    }
    if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity) {
      throw new UnsupportedOperationException(
        "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
    }
    val range = if (min != max) {
      // Range.Double.inclusive(min, max, increment)
      // The above code doesn't always work. See Scala bug #SI-8782.
      // https://issues.scala-lang.org/browse/SI-8782
      customRange(min, max, bucketCount)
    } else {
      List(min, min)
    }
    val buckets = range.toArray
    (buckets, histogram(buckets, true))
  }

  /**
   * Compute a histogram using the provided buckets. The buckets are all open
   * to the right except for the last one, which is closed.
   * For example, for the array [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50],
   * i.e. {@code 1<=x<10, 10<=x<20, 20<=x<=50}.
   * On an input of 1 and 50 the resulting histogram would be 1, 0, 1.
   *
   * @note If your buckets are evenly spaced (e.g. [0, 10, 20, 30]), setting evenBuckets to
   * true switches the per-element cost from an O(log n) binary search to O(1)
   * (where n = # buckets).
   * buckets must be sorted and must not contain any duplicates.
   * The buckets array must have at least two elements.
   * All NaN entries are treated the same. If you have a NaN bucket, it must be
   * the maximum value of the last position, and all NaN entries will be counted
   * in that bucket.
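   *
   * A hedged usage sketch, mirroring the example above (assumes an `RDD[Double]` named `rdd`
   * holding 1.0 and 50.0):
   * {{{
   *   rdd.histogram(Array(1.0, 10.0, 20.0, 50.0))  // Array(1, 0, 1)
   * }}}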
   */
  def histogram(
      buckets: Array[Double],
      evenBuckets: Boolean = false): Array[Long] = self.withScope {
    if (buckets.length < 2) {
      throw new IllegalArgumentException("buckets array must have at least two elements")
    }
    // The histogramPartition function computes the partial histogram for a given
    // partition. The provided bucketFunction determines which bucket in the array
    // to increment, or returns None if there is no matching bucket. This is done so we can
    // specialize for uniformly distributed buckets and save the O(log n) binary
    // search cost.
    def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
        Iterator[Array[Long]] = {
      val counters = new Array[Long](buckets.length - 1)
      while (iter.hasNext) {
        bucketFunction(iter.next()) match {
          case Some(x: Int) => counters(x) += 1
          case _ => // No-Op
        }
      }
      Iterator(counters)
    }
    // Merge the counters.
    def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
      a1.indices.foreach(i => a1(i) += a2(i))
      a1
    }
    // Basic bucket function. This works using Java's built in Array
    // binary search. Takes log(size(buckets))
    def basicBucketFunction(e: Double): Option[Int] = {
      val location = java.util.Arrays.binarySearch(buckets, e)
      if (location < 0) {
        // If the location is less than 0 then the insertion point in the array
        // to keep it sorted is -location-1
        val insertionPoint = -location-1
        // If we have to insert before the first element or after the last one
        // it's out of bounds.
        // We do this rather than buckets.lengthCompare(insertionPoint)
        // because Array[Double] fails to override it (for now).
        if (insertionPoint > 0 && insertionPoint < buckets.length) {
          Some(insertionPoint-1)
        } else {
          None
        }
      } else if (location < buckets.length - 1) {
        // Exact match on a boundary other than the last one; the element belongs to the
        // bucket that starts at that boundary.
        Some(location)
      } else {
        // Exact match to the last element, which belongs to the final (right-closed) bucket.
        Some(location - 1)
      }
    }
    // Determine the bucket in constant time. Requires that the buckets are evenly spaced.
    def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = {
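      // Illustrative worked example (not from the original source): with buckets
      // [0, 10, 20, 30] we get min = 0, max = 30, count = 3; an element e = 15.0 maps to
      // ((15 - 0) / 30) * 3 = 1.5, truncated to bucket 1, i.e. the [10, 20) bucket, while
      // e = 30.0 computes 3 and is clamped to the last bucket, index 2.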
      // Fail fast if the input is NaN or falls outside the [min, max] range
      if (e.isNaN || e < min || e > max) {
        None
      } else {
        // Compute the ratio of e's distance along the range to the total range first,
        // for better precision
        val bucketNumber = (((e - min) / (max - min)) * count).toInt
        // bucketNumber should be less than count, but will equal count if e == max, in which
        // case it's part of the last end-range-inclusive bucket, so return count - 1
        Some(math.min(bucketNumber, count - 1))
      }
    }
    // Decide which bucket function to pass to histogramPartition. We decide here
    // rather than having a general function so that the decision need only be made
    // once rather than once per partition
    val bucketFunction = if (evenBuckets) {
      fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _
    } else {
      basicBucketFunction _
    }
    if (self.partitions.length == 0) {
      new Array[Long](buckets.length - 1)
    } else {
      // reduce() requires a non-empty RDD. This works because mapPartitions produces
      // non-empty partitions (each yielding a counter array) even from empty ones, but it
      // doesn't handle the no-partitions case, which is handled by the branch above.
      self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
    }
  }

}