1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.expressions
19
20import org.apache.spark.sql.Row
21import org.apache.spark.sql.catalyst.InternalRow
22import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
23import org.apache.spark.sql.types._
24import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
25
/**
 * An extended version of [[InternalRow]] that implements all special getters, toString
 * and equals/hashCode by `genericGet`.
 *
 * Implementations supply `genericGet` (and members such as `numFields`, which this trait reads
 * but does not define). Every typed getter below simply casts the boxed value returned by
 * `genericGet`; a `null` value is treated as a SQL NULL.
 */
trait BaseGenericInternalRow extends InternalRow {

  /** Returns the boxed value stored at `ordinal`, or `null` when the field is NULL. */
  protected def genericGet(ordinal: Int): Any

  // default implementation (slow): unchecked cast of the boxed value to the requested type
  private def getAs[T](ordinal: Int): T = genericGet(ordinal).asInstanceOf[T]
  override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
  override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal)
  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
  override def getByte(ordinal: Int): Byte = getAs(ordinal)
  override def getShort(ordinal: Int): Short = getAs(ordinal)
  override def getInt(ordinal: Int): Int = getAs(ordinal)
  override def getLong(ordinal: Int): Long = getAs(ordinal)
  override def getFloat(ordinal: Int): Float = getAs(ordinal)
  override def getDouble(ordinal: Int): Double = getAs(ordinal)
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
  override def getMap(ordinal: Int): MapData = getAs(ordinal)
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)

  /** Returns true if at least one field of this row is NULL. */
  override def anyNull: Boolean = {
    val len = numFields
    var i = 0
    // Advance past non-null fields; stopping early means a NULL was found.
    while (i < len && !isNullAt(i)) {
      i += 1
    }
    i < len
  }

  /** Renders the row as `[v0,v1,...]`, or `[empty row]` when it has no fields. */
  override def toString: String = {
    val len = numFields
    if (len == 0) {
      "[empty row]"
    } else {
      val sb = new StringBuilder("[")
      sb.append(genericGet(0))
      var i = 1
      while (i < len) {
        sb.append(',').append(genericGet(i))
        i += 1
      }
      sb.append(']').toString
    }
  }

  /**
   * Field-wise equality against another [[BaseGenericInternalRow]].
   *
   * Binary fields are compared by content, and NaN floats/doubles are considered equal to NaN
   * of the same type. Any other object (including `null`) is never equal to this row.
   */
  override def equals(o: Any): Boolean = o match {
    case other: BaseGenericInternalRow =>
      val len = numFields
      if (len != other.numFields) {
        false
      } else {
        var same = true
        var i = 0
        while (same && i < len) {
          val thisIsNull = isNullAt(i)
          // Null-ness must match; values are only compared when both sides are non-null.
          same = thisIsNull == other.isNullAt(i) &&
            (thisIsNull || fieldValuesEqual(genericGet(i), other.genericGet(i)))
          i += 1
        }
        same
      }
    case _ => false
  }

  /** Compares two non-null field values using the row's equality rules. */
  private def fieldValuesEqual(o1: Any, o2: Any): Boolean = o1 match {
    case b1: Array[Byte] =>
      o2 match {
        case b2: Array[Byte] => java.util.Arrays.equals(b1, b2)
        case _ => false
      }
    case f1: Float if java.lang.Float.isNaN(f1) =>
      o2 match {
        case f2: Float => java.lang.Float.isNaN(f2)
        case _ => false
      }
    case d1: Double if java.lang.Double.isNaN(d1) =>
      o2 match {
        case d2: Double => java.lang.Double.isNaN(d2)
        case _ => false
      }
    case _ => o1 == o2
  }

  // Custom hashCode function that matches the efficient code generated version.
  // NOTE(review): `==` on boxed doubles treats 0.0 and -0.0 as equal in `equals`, while
  // `doubleToLongBits` below distinguishes them, so such rows may be equal yet hash
  // differently. Left unchanged to keep parity with generated code; confirm upstream
  // normalization of -0.0.
  override def hashCode: Int = {
    var result: Int = 37
    var i = 0
    val len = numFields
    while (i < len) {
      val fieldHash: Int =
        if (isNullAt(i)) {
          0
        } else {
          genericGet(i) match {
            case b: Boolean => if (b) 0 else 1
            case b: Byte => b.toInt
            case s: Short => s.toInt
            case n: Int => n
            case l: Long => (l ^ (l >>> 32)).toInt
            case f: Float => java.lang.Float.floatToIntBits(f)
            case d: Double =>
              val bits = java.lang.Double.doubleToLongBits(d)
              (bits ^ (bits >>> 32)).toInt
            case a: Array[Byte] => java.util.Arrays.hashCode(a)
            case other => other.hashCode()
          }
        }
      result = 37 * result + fieldHash
      i += 1
    }
    result
  }
}
159
/**
 * A [[Row]] backed directly by an array of objects. The array is wrapped rather than copied,
 * so it could technically be mutated after the row is created; doing so is not allowed.
 */
class GenericRow(protected[sql] val values: Array[Any]) extends Row {

  /** No-arg constructor for serialization. */
  protected def this() = this(null)

  /** Creates a row of `size` fields, each initially `null`. */
  def this(size: Int) = this(Array.ofDim[Any](size))

  override def get(i: Int): Any = values.apply(i)

  override def length: Int = values.length

  // Hands out a fresh array so callers never see the backing storage.
  override def toSeq: Seq[Any] = values.clone()

  // This class exposes no mutators, so the instance itself serves as its own copy.
  override def copy(): GenericRow = this
}
179
/**
 * A generic [[Row]] that additionally carries its [[StructType]] schema, which makes
 * looking up a field's position by name possible.
 */
class GenericRowWithSchema(
    values: Array[Any],
    override val schema: StructType)
  extends GenericRow(values) {

  /** No-arg constructor for serialization; both the values and the schema start as `null`. */
  protected def this() = this(null, null)

  override def fieldIndex(name: String): Int = {
    // Name resolution is delegated entirely to the schema.
    schema.fieldIndex(name)
  }
}
188
/**
 * An internal row implementation that uses an array of objects as the underlying storage.
 *
 * The array passed to the constructor is used as-is (not copied), so the row shares it with
 * the caller. The row is mutable: `update` and `setNullAt` write into that array in place.
 */
class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow {
  /** No-arg constructor for serialization. */
  protected def this() = this(null)

  /** Creates a row with `size` fields, each initially `null`. */
  def this(size: Int) = this(new Array[Any](size))

  override protected def genericGet(ordinal: Int): Any = values(ordinal)

  // Returns a clone so the caller cannot modify this row's storage through the result.
  override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values.clone()

  override def numFields: Int = values.length

  override def setNullAt(i: Int): Unit = { values(i) = null }

  override def update(i: Int, value: Any): Unit = { values(i) = value }

  /**
   * Returns a new row backed by a shallow copy of `values`.
   *
   * Because this row is mutable through `update`/`setNullAt`, returning `this` would let
   * later writes leak into the supposed copy. Field values themselves are shared, not
   * deep-copied. NOTE(review): confirm that callers treat nested values as immutable.
   */
  override def copy(): GenericInternalRow = new GenericInternalRow(values.clone())
}
212