/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import java.lang.management.ManagementFactory
import java.lang.reflect.{Field, Modifier}
import java.util.{IdentityHashMap, Random}

import scala.collection.mutable.ArrayBuffer
import scala.runtime.ScalaRunTime

import com.google.common.collect.MapMaker

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.util.collection.OpenHashSet

/**
 * A trait that allows a class to give [[SizeEstimator]] a more accurate size estimation.
 * When a class extends it, [[SizeEstimator]] will query `estimatedSize` first and use the
 * returned value as the size of the object, instead of doing the estimation work itself.
 * The difference between a [[KnownSizeEstimation]] and
 * [[org.apache.spark.util.collection.SizeTracker]] is that a
 * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to
 * estimate the size, whereas a [[KnownSizeEstimation]] can provide a better estimation without
 * using [[SizeEstimator]].
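 *
 * A hypothetical sketch of an implementation (the class name and overhead constant below are
 * made up for illustration only):
 * {{{
 *   class PreSizedBuffer(bytes: Array[Byte]) extends KnownSizeEstimation {
 *     // Report the backing array's length plus a rough fixed overhead, so that
 *     // SizeEstimator does not need to walk this object's fields.
 *     override def estimatedSize: Long = bytes.length.toLong + 32L
 *   }
 * }}}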
 */
private[spark] trait KnownSizeEstimation {
  def estimatedSize: Long
}

/**
 * :: DeveloperApi ::
 * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in
 * memory-aware caches.
 *
 * Based on the following JavaWorld article:
 * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
 */
@DeveloperApi
object SizeEstimator extends Logging {

  /**
   * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate
   * includes space taken up by objects referenced by the given object, their references, and so
   * on.
   *
   * This is useful for determining the amount of heap space a broadcast variable will occupy on
   * each executor or the amount of space each object will take when caching objects in
   * deserialized form. This is not the same as the serialized size of the object, which will
   * typically be much smaller.
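   *
   * A minimal usage sketch (the figure quoted is approximate and depends on the JVM and its
   * settings):
   * {{{
   *   val footprint = SizeEstimator.estimate(new Array[Long](1000))
   *   // roughly 8016 bytes on a 64-bit JVM with compressed oops:
   *   // a 16-byte array header plus 1000 * 8 bytes of elements
   * }}}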
   */
  def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef])

  // Sizes of primitive types
  private val BYTE_SIZE = 1
  private val BOOLEAN_SIZE = 1
  private val CHAR_SIZE = 2
  private val SHORT_SIZE = 2
  private val INT_SIZE = 4
  private val LONG_SIZE = 8
  private val FLOAT_SIZE = 4
  private val DOUBLE_SIZE = 8

  // Fields can be primitive types, whose sizes are 1, 2, 4 or 8 bytes, or they can be pointers.
  // The size of a pointer is 4 or 8 depending on the JVM (32-bit or 64-bit) and the
  // UseCompressedOops flag. The sizes should be in descending order, as we use that ordering for
  // field placement.
  private val fieldSizes = List(8, 4, 2, 1)

  // Alignment boundary for objects
  // TODO: Is this arch-dependent?
  private val ALIGN_SIZE = 8

  // A cache of ClassInfo objects for each class
  // We use weakKeys to allow GC of dynamically created classes
  private val classInfos = new MapMaker().weakKeys().makeMap[Class[_], ClassInfo]()

  // Object and pointer sizes are arch dependent
  private var is64bit = false

  // Size of an object reference
  // Based on https://wikis.oracle.com/display/HotSpotInternals/CompressedOops
  private var isCompressedOops = false
  private var pointerSize = 4

  // Minimum size of a java.lang.Object
  private var objectSize = 8

  initialize()

  // Sets object size, pointer size based on architecture and CompressedOops settings
  // from the JVM.
  private def initialize() {
    val arch = System.getProperty("os.arch")
    is64bit = arch.contains("64") || arch.contains("s390x")
    isCompressedOops = getIsCompressedOops

    objectSize = if (!is64bit) 8 else {
      if (!isCompressedOops) {
        16
      } else {
        12
      }
    }
    pointerSize = if (is64bit && !isCompressedOops) 8 else 4
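    // For example (values derived from the branches above, not measured): a 64-bit JVM with
    // compressed oops enabled gives objectSize = 12 and pointerSize = 4; with compressed oops
    // disabled it gives 16 and 8; a 32-bit JVM gives 8 and 4.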
    classInfos.clear()
    classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil))
  }

  private def getIsCompressedOops: Boolean = {
    // This is only used by tests to override the detection of compressed oops. The test
    // actually uses a system property instead of a SparkConf, so we'll stick with that.
    if (System.getProperty("spark.test.useCompressedOops") != null) {
      return System.getProperty("spark.test.useCompressedOops").toBoolean
    }

    // java.vm.info provides compressed ref info for IBM JDKs
    if (System.getProperty("java.vendor").contains("IBM")) {
      return System.getProperty("java.vm.info").contains("Compressed Ref")
    }

    try {
      val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
      val server = ManagementFactory.getPlatformMBeanServer()

      // NOTE: This should throw an exception in non-Sun JVMs
      // scalastyle:off classforname
      val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean")
      val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption",
          Class.forName("java.lang.String"))
      // scalastyle:on classforname

      val bean = ManagementFactory.newPlatformMXBeanProxy(server,
        hotSpotMBeanName, hotSpotMBeanClass)
      // TODO: We could use reflection on the VMOption returned?
      getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
    } catch {
      case e: Exception =>
        // Guess whether UseCompressedOops is enabled based on whether maxMemory < 32 GB
        val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
        val guessInWords = if (guess) "yes" else "no"
        logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords)
        return guess
    }
  }

  /**
   * The state of an ongoing size estimation. Contains a stack of objects to visit as well as an
   * IdentityHashMap of visited objects, and provides utility methods for enqueueing new objects
   * to visit.
   */
  private class SearchState(val visited: IdentityHashMap[AnyRef, AnyRef]) {
    val stack = new ArrayBuffer[AnyRef]
    var size = 0L

    def enqueue(obj: AnyRef) {
      if (obj != null && !visited.containsKey(obj)) {
        visited.put(obj, null)
        stack += obj
      }
    }

    def isFinished(): Boolean = stack.isEmpty

    def dequeue(): AnyRef = {
      val elem = stack.last
      stack.trimEnd(1)
      elem
    }
  }

  /**
   * Cached information about each class. We remember two things: the "shell size" of the class
   * (size of all non-static fields plus the java.lang.Object size), and any fields that are
   * pointers to objects.
   */
  private class ClassInfo(
    val shellSize: Long,
    val pointerFields: List[Field]) {}

  private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = {
    val state = new SearchState(visited)
    state.enqueue(obj)
    while (!state.isFinished) {
      visitSingleObject(state.dequeue(), state)
    }
    state.size
  }

  private def visitSingleObject(obj: AnyRef, state: SearchState) {
    val cls = obj.getClass
    if (cls.isArray) {
      visitArray(obj, cls, state)
    } else if (cls.getName.startsWith("scala.reflect")) {
      // Many objects in the scala.reflect package reference global reflection objects which, in
      // turn, reference many other large global objects. Do nothing in this case.
    } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) {
      // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses
      // the size estimator since it references the whole REPL. Do nothing in this case. In
      // general all ClassLoaders and Classes will be shared between objects anyway.
    } else {
      obj match {
        case s: KnownSizeEstimation =>
          state.size += s.estimatedSize
        case _ =>
          val classInfo = getClassInfo(cls)
          state.size += alignSize(classInfo.shellSize)
          for (field <- classInfo.pointerFields) {
            state.enqueue(field.get(obj))
          }
      }
    }
  }

  // Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling.
  private val ARRAY_SIZE_FOR_SAMPLING = 400
  private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING

  private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) {
    val length = ScalaRunTime.array_length(array)
    val elementClass = arrayClass.getComponentType()

    // Arrays have an object header and a length field, which is an integer
    var arrSize: Long = alignSize(objectSize + INT_SIZE)

    if (elementClass.isPrimitive) {
      arrSize += alignSize(length.toLong * primitiveSize(elementClass))
      state.size += arrSize
    } else {
      arrSize += alignSize(length.toLong * pointerSize)
      state.size += arrSize

      if (length <= ARRAY_SIZE_FOR_SAMPLING) {
        var arrayIndex = 0
        while (arrayIndex < length) {
          state.enqueue(ScalaRunTime.array_apply(array, arrayIndex).asInstanceOf[AnyRef])
          arrayIndex += 1
        }
      } else {
        // Estimate the size of a large array by sampling elements without replacement.
        // To avoid over-counting shared objects that the array elements may link to, sample
        // twice and use the smaller of the two samples when extrapolating the array size.
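        // As a rough illustration (numbers assumed, not from a real run): with length = 1000 and
        // ARRAY_SAMPLE_SIZE = 100, if the two samples measure 12 KB and 10 KB, the larger sample
        // is counted in full and the remaining 900 elements are extrapolated from the smaller
        // one: 12 KB + 10 KB * ((1000 - 100) / 100) = 102 KB.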
        val rand = new Random(42)
        val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE)
        val s1 = sampleArray(array, state, rand, drawn, length)
        val s2 = sampleArray(array, state, rand, drawn, length)
        val size = math.min(s1, s2)
        state.size += math.max(s1, s2) +
          (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong
      }
    }
  }

  private def sampleArray(
      array: AnyRef,
      state: SearchState,
      rand: Random,
      drawn: OpenHashSet[Int],
      length: Int): Long = {
    var size = 0L
    for (i <- 0 until ARRAY_SAMPLE_SIZE) {
      var index = 0
      do {
        index = rand.nextInt(length)
      } while (drawn.contains(index))
      drawn.add(index)
      val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef]
      if (obj != null) {
        size += SizeEstimator.estimate(obj, state.visited).toLong
      }
    }
    size
  }

  private def primitiveSize(cls: Class[_]): Int = {
    if (cls == classOf[Byte]) {
      BYTE_SIZE
    } else if (cls == classOf[Boolean]) {
      BOOLEAN_SIZE
    } else if (cls == classOf[Char]) {
      CHAR_SIZE
    } else if (cls == classOf[Short]) {
      SHORT_SIZE
    } else if (cls == classOf[Int]) {
      INT_SIZE
    } else if (cls == classOf[Long]) {
      LONG_SIZE
    } else if (cls == classOf[Float]) {
      FLOAT_SIZE
    } else if (cls == classOf[Double]) {
      DOUBLE_SIZE
    } else {
      throw new IllegalArgumentException(
      "Non-primitive class " + cls + " passed to primitiveSize()")
    }
  }

  /**
   * Get or compute the ClassInfo for a given class.
   */
  private def getClassInfo(cls: Class[_]): ClassInfo = {
    // Check whether we've already cached a ClassInfo for this class
    val info = classInfos.get(cls)
    if (info != null) {
      return info
    }

    val parent = getClassInfo(cls.getSuperclass)
    var shellSize = parent.shellSize
    var pointerFields = parent.pointerFields
    val sizeCount = Array.fill(fieldSizes.max + 1)(0)

    // iterate through the fields of this class and gather information.
    for (field <- cls.getDeclaredFields) {
      if (!Modifier.isStatic(field.getModifiers)) {
        val fieldClass = field.getType
        if (fieldClass.isPrimitive) {
          sizeCount(primitiveSize(fieldClass)) += 1
        } else {
          field.setAccessible(true) // Enable future get()'s on this field
          sizeCount(pointerSize) += 1
          pointerFields = field :: pointerFields
        }
      }
    }

    // Based on the simulated field layout code in Aleksey Shipilev's report:
    // http://cr.openjdk.java.net/~shade/papers/2013-shipilev-fieldlayout-latest.pdf
    // The code is in Figure 9.
    // The simplified idea of field layout consists of 4 parts (see the report for more details):
    //
    // 1. field alignment: HotSpot lays out the fields aligned by their size.
    // 2. object alignment: HotSpot rounds instance size up to 8 bytes.
    // 3. consistent field layouts throughout the hierarchy: this means we should lay out the
    // superclass first, and we can use the superclass's shellSize as a starting point to lay out
    // the other fields in this class.
    // 4. class alignment: HotSpot rounds field blocks up to HeapOopSize, not 4 bytes (confirmed
    // with Aleksey, see https://bugs.openjdk.java.net/browse/CODETOOLS-7901322).
    //
    // The real-world field layout is much more complicated: there are three kinds of field order
    // in Java 8, and we don't consider the @Contended annotation introduced by Java 8.
    // See the HotSpot class file parser code, layout_fields method, for more details:
    // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp
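    // A hypothetical walk-through of the loop below (assuming a 64-bit JVM with compressed oops,
    // so objectSize = 12 and pointerSize = 4): for a class whose only fields are one long and one
    // reference, shellSize starts at 12. Placing the 8-byte long at the next 8-byte boundary
    // (offset 16) makes alignedSize 24 while shellSize advances to 20; the 4-byte reference then
    // fits at offset 20, so both reach 24, and the final round-up to pointerSize keeps the shell
    // size at 24 bytes.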
    var alignedSize = shellSize
    for (size <- fieldSizes if sizeCount(size) > 0) {
      val count = sizeCount(size).toLong
      // If there are internal gaps, smaller fields can fit in.
      alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count)
      shellSize += size * count
    }

    // Choose the larger value as the new shellSize (clearly alignedSize >= shellSize), and
    // round the instance field block up to pointerSize.
    shellSize = alignSizeUp(alignedSize, pointerSize)

    // Create and cache a new ClassInfo
    val newInfo = new ClassInfo(shellSize, pointerFields)
    classInfos.put(cls, newInfo)
    newInfo
  }

  private def alignSize(size: Long): Long = alignSizeUp(size, ALIGN_SIZE)

  /**
   * Compute the aligned size. The alignSize must be a power of 2 (2^n), otherwise the result
   * will be wrong. When alignSize = 2^n, alignSize - 1 = 2^n - 1, whose binary representation
   * has only n trailing 1s (0b00...001..1), so ~(alignSize - 1) is 0b11..110..0. Hence,
   * (size + alignSize - 1) & ~(alignSize - 1) sets the last n bits to zero, which yields a
   * multiple of alignSize.
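   *
   * A worked example: alignSizeUp(13, 8) = (13 + 7) & ~7 = 20 & ...11111000 = 16, the next
   * multiple of 8 at or above 13.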
   */
  private def alignSizeUp(size: Long, alignSize: Int): Long =
    (size + alignSize - 1) & ~(alignSize - 1)
}