/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.aggregate

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap

/**
 * The Percentile aggregate function returns the exact percentile(s) of numeric column `child` at
 * the given percentage(s) with value range in [0.0, 1.0].
 *
 * The operator is bound to the slower sort based aggregation path because the number of elements
 * and their partial order cannot be determined in advance. Therefore we have to keep every
 * distinct element (together with its number of occurrences) in memory, and too many distinct
 * elements can cause GC pauses and eventually OutOfMemory errors.
 *
 * The aggregation buffer is a map from each distinct non-null input value to its count.
 *
 * @param child child expression that produces a numeric column value with
 *              `child.eval(inputRow)`
 * @param percentageExpression Expression that represents a single percentage value or an array of
 *                             percentage values. Each percentage value must be non-null and in
 *                             the range [0.0, 1.0].
 */
@ExpressionDescription(
  usage =
    """
      _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the
      given percentage. The value of percentage must be between 0.0 and 1.0.

      _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array
      of numeric column `col` at the given percentage(s). Each value of the percentage array must
      be between 0.0 and 1.0.
    """)
case class Percentile(
    child: Expression,
    percentageExpression: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] {

  def this(child: Expression, percentageExpression: Expression) = {
    this(child, percentageExpression, 0, 0)
  }

  override def prettyName: String = "percentile"

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Percentile =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Percentile =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  // Mark the following as lazy so that percentageExpression is not evaluated during tree
  // transformation (it is only known to be foldable once checkInputDataTypes has passed).
  @transient
  private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType]

  /**
   * True when the percentage expression evaluates to null, or to an array containing a null
   * element. Only evaluated from [[checkInputDataTypes]] after the foldability check.
   */
  @transient
  private lazy val hasNullPercentage: Boolean = percentageExpression.eval() match {
    case null => true
    case arrayData: ArrayData => (0 until arrayData.numElements()).exists(i => arrayData.isNullAt(i))
    case _ => false
  }

  /** The requested percentages as doubles, in the order they were given. */
  @transient
  private lazy val percentages: Seq[Double] =
    (percentageExpression.dataType, percentageExpression.eval()) match {
      case (_, num: Double) => Seq(num)
      case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
        val numericArray = arrayData.toObjectArray(baseType)
        numericArray.map { x =>
          baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
        }.toSeq
      case other =>
        throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages")
    }

  override def children: Seq[Expression] = child :: percentageExpression :: Nil

  // Returns null for empty inputs
  override def nullable: Boolean = true

  override lazy val dataType: DataType = percentageExpression.dataType match {
    case _: ArrayType => ArrayType(DoubleType, false)
    case _ => DoubleType
  }

  override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
    case _: ArrayType => Seq(NumericType, ArrayType)
    case _ => Seq(NumericType, DoubleType)
  }

  // Check the inputTypes are valid, and the percentageExpression satisfies:
  // 1. percentageExpression must be foldable;
  // 2. percentage(s) must not be null;
  // 3. percentage(s) must be in the range [0.0, 1.0].
  override def checkInputDataTypes(): TypeCheckResult = {
    // Validate the inputTypes
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) {
      defaultCheck
    } else if (!percentageExpression.foldable) {
      // percentageExpression must be foldable
      TypeCheckFailure("The percentage(s) must be a constant literal, " +
        s"but got $percentageExpression")
    } else if (hasNullPercentage) {
      // Without this check a null scalar surfaces as a misleading "Invalid data type" error from
      // `percentages`, and a null array element is handed to `numeric.toDouble`, which is not
      // null-safe.
      TypeCheckFailure(s"Percentage value must not be null, but got $percentageExpression")
    } else if (percentages.exists(percentage => percentage < 0.0 || percentage > 1.0)) {
      // percentage(s) must be in the range [0.0, 1.0]
      TypeCheckFailure("Percentage(s) must be between 0.0 and 1.0, " +
        s"but got $percentageExpression")
    } else {
      TypeCheckSuccess
    }
  }

  /** Converts a buffered key (a `Decimal` or a boxed Java number) to a double. */
  private def toDoubleValue(d: Any): Double = d match {
    case d: Decimal => d.toDouble
    case n: Number => n.doubleValue
  }

  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
    // Initialize new counts map instance here.
    new OpenHashMap[AnyRef, Long]()
  }

  override def update(buffer: OpenHashMap[AnyRef, Long], input: InternalRow): Unit = {
    val key = child.eval(input).asInstanceOf[AnyRef]

    // Null values are ignored in counts map.
    if (key != null) {
      buffer.changeValue(key, 1L, _ + 1L)
    }
  }

  override def merge(buffer: OpenHashMap[AnyRef, Long], other: OpenHashMap[AnyRef, Long]): Unit = {
    other.foreach { case (key, count) =>
      buffer.changeValue(key, count, _ + count)
    }
  }

  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
    generateOutput(getPercentiles(buffer))
  }

  /**
   * Computes one percentile per requested percentage from the counts map.
   *
   * @return the percentiles in the order of `percentages`, or an empty Seq when no non-null
   *         value was aggregated.
   */
  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
    if (buffer.isEmpty) {
      Seq.empty
    } else {
      val ordering =
        child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]]
      val sortedCounts = buffer.toArray.sortBy(_._1)(ordering)

      // Materialize keys and running totals once, as arrays, so that every requested percentile
      // gets O(log n) lookups instead of re-building the counts array per percentile.
      // accumulatedCounts(i) is the number of values <= sortedKeys(i).
      val sortedKeys: Array[AnyRef] = sortedCounts.map(_._1)
      val accumulatedCounts: Array[Long] = sortedCounts.map(_._2).scanLeft(0L)(_ + _).tail
      val maxPosition = accumulatedCounts.last - 1

      percentages.map { percentage =>
        getPercentile(sortedKeys, accumulatedCounts, maxPosition * percentage)
      }
    }
  }

  private def generateOutput(results: Seq[Double]): Any = {
    if (results.isEmpty) {
      null
    } else if (returnPercentileArray) {
      new GenericArrayData(results)
    } else {
      results.head
    }
  }

  /**
   * Get the percentile value.
   *
   * This function has been based upon similar function from HIVE
   * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
   *
   * @param sortedKeys distinct input values in ascending order
   * @param accumulatedCounts `accumulatedCounts(i)` is the number of input values that are
   *                          less than or equal to `sortedKeys(i)`; strictly increasing
   * @param position 0-based (possibly fractional) position in the sorted input
   */
  private def getPercentile(
      sortedKeys: Array[AnyRef],
      accumulatedCounts: Array[Long],
      position: Double): Double = {
    // We may need to do linear interpolation to get the exact percentile
    val lower = position.floor.toLong
    val higher = position.ceil.toLong

    // The value at 0-based position p belongs to the first key whose accumulated count is at
    // least p + 1; binary search finds it.
    val lowerIndex = binarySearchCount(accumulatedCounts, 0, accumulatedCounts.length, lower + 1)
    val lowerKey = sortedKeys(lowerIndex)

    if (higher == lower) {
      // no interpolation needed because position does not have a fraction
      toDoubleValue(lowerKey)
    } else {
      val higherIndex =
        binarySearchCount(accumulatedCounts, 0, accumulatedCounts.length, higher + 1)
      val higherKey = sortedKeys(higherIndex)
      if (higherKey == lowerKey) {
        // no interpolation needed because lower position and higher position has the same key
        toDoubleValue(lowerKey)
      } else {
        // Linear interpolation to get the exact percentile
        (higher - position) * toDoubleValue(lowerKey) +
          (position - lower) * toDoubleValue(higherKey)
      }
    }
  }

  /**
   * Binary-searches `countsArray[start, end)` for `value` and returns its index if present,
   * otherwise the insertion point (the index of the first element greater than `value`).
   */
  private def binarySearchCount(
      countsArray: Array[Long], start: Int, end: Int, value: Long): Int = {
    util.Arrays.binarySearch(countsArray, start, end, value) match {
      case ix if ix < 0 => -(ix + 1)
      case ix => ix
    }
  }

  /**
   * Serializes the counts map as a sequence of records, each an `Int` byte length followed by an
   * `UnsafeRow(key, count)`, terminated by a length of -1.
   */
  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
    val buffer = new Array[Byte](4 << 10) // 4K scratch buffer passed to UnsafeRow.writeToStream
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    try {
      val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType))
      // Write pairs in counts map to byte buffer.
      obj.foreach { case (key, count) =>
        val row = InternalRow.apply(key, count)
        val unsafeRow = projection.apply(row)
        out.writeInt(unsafeRow.getSizeInBytes)
        unsafeRow.writeToStream(out, buffer)
      }
      out.writeInt(-1)
      out.flush()

      bos.toByteArray
    } finally {
      out.close()
      bos.close()
    }
  }

  /** Inverse of [[serialize]]: reads length-prefixed `UnsafeRow(key, count)` records until -1. */
  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
    val bis = new ByteArrayInputStream(bytes)
    val ins = new DataInputStream(bis)
    try {
      val counts = new OpenHashMap[AnyRef, Long]
      // Read unsafeRow size and content in bytes.
      var sizeOfNextRow = ins.readInt()
      while (sizeOfNextRow >= 0) {
        val bs = new Array[Byte](sizeOfNextRow)
        ins.readFully(bs)
        val row = new UnsafeRow(2)
        row.pointTo(bs, sizeOfNextRow)
        // Insert the pairs into counts map.
        val key = row.get(0, child.dataType)
        val count = row.getLong(1)
        counts.update(key, count)
        sizeOfNextRow = ins.readInt()
      }

      counts
    } finally {
      ins.close()
      bis.close()
    }
  }
}