1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.util
19
20import java.io.Serializable
21import java.util.{PriorityQueue => JPriorityQueue}
22
23import scala.collection.JavaConverters._
24import scala.collection.generic.Growable
25
/**
 * Bounded priority queue. This class wraps a `java.util.PriorityQueue` and modifies it so
 * that only the top K (`maxSize`) elements are retained. The top K elements are defined by
 * the implicit `Ordering[A]`: once the queue is full, a new element is kept only if it is
 * strictly greater than the smallest element currently retained, which it then replaces.
 *
 * Note that [[iterator]] exposes the iteration order of the wrapped Java queue, which is
 * not guaranteed to be sorted; sort the result explicitly if an ordered view is needed.
 *
 * @param maxSize maximum number of elements to retain; must be positive
 * @param ord     ordering used to rank elements
 * @tparam A      type of the elements held by the queue
 * @throws IllegalArgumentException if `maxSize` is not positive
 */
private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
  extends Iterable[A] with Growable[A] with Serializable {

  // java.util.PriorityQueue rejects an initial capacity below 1 with an
  // IllegalArgumentException; fail up front with the same exception type but a message
  // that names the argument the caller actually passed.
  require(maxSize > 0, s"maxSize must be positive, but got $maxSize")

  // The head of this queue is the least element with respect to `ord`, i.e. the first
  // candidate for eviction once `maxSize` elements are held.
  private val underlying = new JPriorityQueue[A](maxSize, ord)

  /** Iterates over the retained elements in the wrapped queue's (unsorted) order. */
  override def iterator: Iterator[A] = underlying.iterator.asScala

  /** Number of elements currently retained; never exceeds `maxSize`. */
  override def size: Int = underlying.size

  /** Offers every element of `xs` to the queue, one at a time, via `+=`. */
  override def ++=(xs: TraversableOnce[A]): this.type = {
    xs.foreach { this += _ }
    this
  }

  /**
   * Offers a single element. While the queue holds fewer than `maxSize` elements the
   * element is always accepted; afterwards it is accepted only if it displaces the
   * current minimum.
   */
  override def +=(elem: A): this.type = {
    if (size < maxSize) {
      underlying.offer(elem)
    } else {
      maybeReplaceLowest(elem)
    }
    this
  }

  /** Offers `elem1`, then `elem2`, then each of `elems`, in that order. */
  override def +=(elem1: A, elem2: A, elems: A*): this.type = {
    this.+=(elem1).+=(elem2).++=(elems)
  }

  /** Removes all elements from the queue. */
  override def clear(): Unit = underlying.clear()

  /**
   * Replaces the smallest retained element with `a` if `a` is strictly greater than it.
   *
   * @return the result of inserting `a` when a replacement happened, `false` when `a`
   *         was discarded (including when it merely ties with the current minimum)
   */
  private def maybeReplaceLowest(a: A): Boolean = {
    // peek() yields null for an empty queue; the null check keeps this hot path free of
    // an Option allocation per inserted element.
    val head = underlying.peek()
    if (head != null && ord.gt(a, head)) {
      underlying.poll()
      underlying.offer(a)
    } else {
      false
    }
  }
}
70