/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.aggregate

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap

/**
 * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
 * the given percentage(s), with each percentage value in the range [0.0, 1.0].
 *
 * The operator is bound to the slower sort-based aggregation path because the number of elements
 * and their partial order cannot be determined in advance. Therefore we have to store all the
 * elements in memory, and too many elements can cause GC pauses and eventually
 * OutOfMemoryErrors.
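 *
 * For example (an illustrative sketch, assuming the function is registered as `percentile`):
 * {{{
 *   SELECT percentile(col, 0.5) FROM VALUES (0), (10) AS tab(col);
 *   -- 5.0, by linear interpolation between the two values
 *   SELECT percentile(col, array(0.25, 0.75)) FROM VALUES (0), (10) AS tab(col);
 *   -- [2.5, 7.5]
 * }}}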
 *
 * @param child child expression that produces the numeric column value with
 *              `child.eval(inputRow)`
 * @param percentageExpression Expression that represents a single percentage value or an array of
 *                             percentage values. Each percentage value must be in the range
 *                             [0.0, 1.0].
 */
@ExpressionDescription(
  usage =
    """
      _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the
      given percentage. The value of percentage must be between 0.0 and 1.0.

      _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array
      of numeric column `col` at the given percentage(s). Each value of the percentage array must
      be between 0.0 and 1.0.
    """)
case class Percentile(
  child: Expression,
  percentageExpression: Expression,
  mutableAggBufferOffset: Int = 0,
  inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] {

  def this(child: Expression, percentageExpression: Expression) = {
    this(child, percentageExpression, 0, 0)
  }

  override def prettyName: String = "percentile"

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Percentile =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Percentile =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
  @transient
  private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType]
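
  // Evaluates the foldable percentage expression once and normalizes it to a Seq[Double].
  // For example (illustrative), a literal `array(0.25D, 0.75D)` yields Seq(0.25, 0.75) here.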
  @transient
  private lazy val percentages =
    (percentageExpression.dataType, percentageExpression.eval()) match {
      case (_, num: Double) => Seq(num)
      case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
        val numericArray = arrayData.toObjectArray(baseType)
        numericArray.map { x =>
          baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
        }.toSeq
      case other =>
        throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages")
    }

  override def children: Seq[Expression] = child :: percentageExpression :: Nil

  // Returns null for empty inputs.
  override def nullable: Boolean = true

  override lazy val dataType: DataType = percentageExpression.dataType match {
    case _: ArrayType => ArrayType(DoubleType, false)
    case _ => DoubleType
  }

  override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
    case _: ArrayType => Seq(NumericType, ArrayType)
    case _ => Seq(NumericType, DoubleType)
  }

  // Check that the input types are valid and that the percentageExpression satisfies:
  // 1. percentageExpression must be foldable;
  // 2. percentage(s) must be in the range [0.0, 1.0].
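  // For example (illustrative): `percentile(col, 1.5)` fails check 2, and a non-literal
  // percentage such as `percentile(col, col2)` fails check 1.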
  override def checkInputDataTypes(): TypeCheckResult = {
    // Validate the input types.
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) {
      defaultCheck
    } else if (!percentageExpression.foldable) {
      // percentageExpression must be foldable.
      TypeCheckFailure("The percentage(s) must be a constant literal, " +
        s"but got $percentageExpression")
    } else if (percentages.exists(percentage => percentage < 0.0 || percentage > 1.0)) {
      // percentage(s) must be in the range [0.0, 1.0].
      TypeCheckFailure("Percentage(s) must be between 0.0 and 1.0, " +
        s"but got $percentageExpression")
    } else {
      TypeCheckSuccess
    }
  }

  private def toDoubleValue(d: Any): Double = d match {
    case d: Decimal => d.toDouble
    case n: Number => n.doubleValue
  }

  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
    // Initialize a new counts map instance here.
    new OpenHashMap[AnyRef, Long]()
  }

  override def update(buffer: OpenHashMap[AnyRef, Long], input: InternalRow): Unit = {
    val key = child.eval(input).asInstanceOf[AnyRef]

    // Null values are ignored in the counts map.
    if (key != null) {
      buffer.changeValue(key, 1L, _ + 1L)
    }
  }

  override def merge(buffer: OpenHashMap[AnyRef, Long], other: OpenHashMap[AnyRef, Long]): Unit = {
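    // Add the other map's counts into this buffer. For example (illustrative), merging
    // {1 -> 2, 2 -> 1} with {1 -> 1} yields {1 -> 3, 2 -> 1}.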
    other.foreach { case (key, count) =>
      buffer.changeValue(key, count, _ + count)
    }
  }

  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
    generateOutput(getPercentiles(buffer))
  }

  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
    if (buffer.isEmpty) {
      return Seq.empty
    }

    val sortedCounts = buffer.toSeq.sortBy(_._1)(
      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
    val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
      case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
    }.tail
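    // For example (illustrative): counts {1 -> 2, 2 -> 1} sort to Seq((1, 2), (2, 1)) and
    // accumulate to Seq((1, 2), (2, 3)), so maxPosition below is 3 - 1 = 2.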
    val maxPosition = accumulatedCounts.last._2 - 1

    percentages.map { percentile =>
      getPercentile(accumulatedCounts, maxPosition * percentile)
    }
  }

  private def generateOutput(results: Seq[Double]): Any = {
    if (results.isEmpty) {
      null
    } else if (returnPercentileArray) {
      new GenericArrayData(results)
    } else {
      results.head
    }
  }

  /**
   * Get the percentile value.
   *
   * This function is based on the similar function from Hive,
   * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
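   *
   * For example (illustrative): for values 0 and 10, each with count 1, position 0.5 gives
   * lower = 0 and higher = 1, so the result is
   * (1 - 0.5) * toDoubleValue(0) + (0.5 - 0) * toDoubleValue(10) = 5.0.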
   */
  private def getPercentile(accumulatedCounts: Seq[(AnyRef, Long)], position: Double): Double = {
    // We may need to do linear interpolation to get the exact percentile.
    val lower = position.floor.toLong
    val higher = position.ceil.toLong

    // Use binary search to find the lower and the higher positions.
    val countsArray = accumulatedCounts.map(_._2).toArray[Long]
    val lowerIndex = binarySearchCount(countsArray, 0, accumulatedCounts.size, lower + 1)
    val higherIndex = binarySearchCount(countsArray, 0, accumulatedCounts.size, higher + 1)

    val lowerKey = accumulatedCounts(lowerIndex)._1
    if (higher == lower) {
      // No interpolation needed because the position does not have a fraction.
      return toDoubleValue(lowerKey)
    }

    val higherKey = accumulatedCounts(higherIndex)._1
    if (higherKey == lowerKey) {
      // No interpolation needed because the lower and higher positions have the same key.
      return toDoubleValue(lowerKey)
    }

    // Linear interpolation to get the exact percentile.
    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
  }

  /**
   * Use a binary search to find the index of the position closest to the current value.
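   *
   * For example (illustrative): searching for 2 in cumulative counts Array(2L, 3L) finds an
   * exact match at index 0, while searching for 1 makes `java.util.Arrays.binarySearch`
   * return -1, which maps back to insertion point 0 below.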
   */
  private def binarySearchCount(
      countsArray: Array[Long], start: Int, end: Int, value: Long): Int = {
    util.Arrays.binarySearch(countsArray, start, end, value) match {
      case ix if ix < 0 => -(ix + 1)
      case ix => ix
    }
  }
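
  // The counts map is serialized as a stream of length-prefixed UnsafeRows holding
  // (key, count) pairs and terminated by a -1 length marker; deserialize below reads the
  // same format back.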
  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
    val buffer = new Array[Byte](4 << 10)  // 4K
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    try {
      val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType))
      // Write the pairs in the counts map to the byte buffer.
      obj.foreach { case (key, count) =>
        val row = InternalRow.apply(key, count)
        val unsafeRow = projection.apply(row)
        out.writeInt(unsafeRow.getSizeInBytes)
        unsafeRow.writeToStream(out, buffer)
      }
      out.writeInt(-1)
      out.flush()

      bos.toByteArray
    } finally {
      out.close()
      bos.close()
    }
  }

  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
    val bis = new ByteArrayInputStream(bytes)
    val ins = new DataInputStream(bis)
    try {
      val counts = new OpenHashMap[AnyRef, Long]
      // Read each UnsafeRow's size and content in bytes.
      var sizeOfNextRow = ins.readInt()
      while (sizeOfNextRow >= 0) {
        val bs = new Array[Byte](sizeOfNextRow)
        ins.readFully(bs)
        val row = new UnsafeRow(2)
        row.pointTo(bs, sizeOfNextRow)
        // Insert the pair into the counts map.
        val key = row.get(0, child.dataType)
        val count = row.get(1, LongType).asInstanceOf[Long]
        counts.update(key, count)
        sizeOfNextRow = ins.readInt()
      }

      counts
    } finally {
      ins.close()
      bis.close()
    }
  }
}