1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.util
19
20import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
21
22import scala.collection.mutable.{Map, Set, Stack}
23import scala.language.existentials
24
25import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type}
26import org.apache.xbean.asm5.Opcodes._
27
28import org.apache.spark.{SparkEnv, SparkException}
29import org.apache.spark.internal.Logging
30
31/**
32 * A cleaner that renders closures serializable if they can be done so safely.
33 */
private[spark] object ClosureCleaner extends Logging {

  /**
   * Get an ASM class reader for a given class from the JAR that loaded it.
   *
   * The class bytes are copied into memory before being handed to ASM; otherwise we can run
   * out of open file handles.
   */
  private[util] def getClassReader(cls: Class[_]): ClassReader = {
    // Resource name relative to the class's package, e.g. "Foo$$anonfun$1.class"
    val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
    val resourceStream = cls.getResourceAsStream(className)
    if (resourceStream == null) {
      // TODO: handle a missing class file explicitly. For now we keep the earlier behavior of
      // handing the null stream straight to ASM's ClassReader constructor, which rejects it.
      new ClassReader(resourceStream)
    } else {
      val baos = new ByteArrayOutputStream(128)
      Utils.copyStream(resourceStream, baos, true)
      new ClassReader(new ByteArrayInputStream(baos.toByteArray))
    }
  }

  // Check whether a class represents a Scala closure
  private def isClosure(cls: Class[_]): Boolean = {
    cls.getName.contains("$anonfun$")
  }

  /**
   * Get a list of the outer objects and their classes of a given closure object, obj.
   *
   * The outer objects are defined as any closures that obj is nested within, plus possibly
   * the class that the outermost closure is in, if any. We stop searching for outer objects
   * beyond that because cloning the user's object is probably not a good idea (whereas we can
   * clone closure objects just fine since we understand how all their fields are used).
   *
   * @return (classes, objects), both ordered from innermost to outermost
   */
  private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = {
    // Lazily read each "$outer" field in declaration order and keep the first non-null one.
    // The outer pointer may be null if we have cleaned this closure before.
    val firstNonNullOuter = obj.getClass.getDeclaredFields.iterator
      .filter(_.getName == "$outer")
      .map { f =>
        f.setAccessible(true)
        (f, f.get(obj))
      }
      .find { case (_, outer) => outer != null }

    firstNonNullOuter match {
      case Some((f, outer)) if isClosure(f.getType) =>
        val (classes, objects) = getOuterClassesAndObjects(outer)
        (f.getType :: classes, outer :: objects)
      case Some((f, outer)) =>
        // Stop at the first $outer that is not a closure
        (f.getType :: Nil, outer :: Nil)
      case None =>
        (Nil, Nil)
    }
  }

  /**
   * Return a list of classes that represent closures enclosed in the given closure object.
   */
  private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = {
    val seen = Set[Class[_]](obj.getClass)
    val stack = Stack[Class[_]](obj.getClass)
    while (!stack.isEmpty) {
      val cr = getClassReader(stack.pop())
      val set = Set[Class[_]]()
      cr.accept(new InnerClosureFinder(set), 0)
      for (cls <- set -- seen) {
        seen += cls
        stack.push(cls)
      }
    }
    (seen - obj.getClass).toList
  }

  /**
   * Clean the given closure in place.
   *
   * More specifically, this renders the given closure serializable as long as it does not
   * explicitly reference unserializable objects.
   *
   * @param closure the closure to clean; a null closure is ignored
   * @param checkSerializable whether to verify that the closure is serializable after cleaning
   * @param cleanTransitively whether to clean enclosing closures transitively
   */
  def clean(
      closure: AnyRef,
      checkSerializable: Boolean = true,
      cleanTransitively: Boolean = true): Unit = {
    clean(closure, checkSerializable, cleanTransitively, Map.empty)
  }

  /**
   * Helper method to clean the given closure in place.
   *
   * The mechanism is to traverse the hierarchy of enclosing closures and null out any
   * references along the way that are not actually used by the starting closure, but are
   * nevertheless included in the compiled anonymous classes. Note that it is unsafe to
   * simply mutate the enclosing closures in place, as other code paths may depend on them.
   * Instead, we clone each enclosing closure and set the parent pointers accordingly.
   *
   * By default, closures are cleaned transitively. This means we detect whether enclosing
   * objects are actually referenced by the starting one, either directly or transitively,
   * and, if not, sever these closures from the hierarchy. In other words, in addition to
   * nulling out unused field references, we also null out any parent pointers that refer
   * to enclosing objects not actually needed by the starting closure. We determine
   * transitivity by tracing through the tree of all methods ultimately invoked by the
   * inner closure and record all the fields referenced in the process.
   *
   * For instance, transitive cleaning is necessary in the following scenario:
   *
   *   class SomethingNotSerializable {
   *     def someValue = 1
   *     def scope(name: String)(body: => Unit) = body
   *     def someMethod(): Unit = scope("one") {
   *       def x = someValue
   *       def y = 2
   *       scope("two") { println(y + 1) }
   *     }
   *   }
   *
   * In this example, scope "two" is not serializable because it references scope "one", which
   * references SomethingNotSerializable. Note that, however, the body of scope "two" does not
   * actually depend on SomethingNotSerializable. This means we can safely null out the parent
   * pointer of a cloned scope "one" and set it as the parent of scope "two", such that scope
   * "two" no longer references SomethingNotSerializable transitively.
   *
   * @param func the starting closure to clean; a null closure is ignored
   * @param checkSerializable whether to verify that the closure is serializable after cleaning
   * @param cleanTransitively whether to clean enclosing closures transitively
   * @param accessedFields a map from a class to a set of its fields that are accessed by
   *                       the starting closure
   */
  private def clean(
      func: AnyRef,
      checkSerializable: Boolean,
      cleanTransitively: Boolean,
      accessedFields: Map[Class[_], Set[String]]): Unit = {

    // Must come before anything that dereferences func (e.g. func.getClass)
    if (func == null) {
      return
    }

    if (!isClosure(func.getClass)) {
      logWarning("Expected a closure; got " + func.getClass.getName)
      return
    }

    // TODO: clean all inner closures first. This requires us to find the inner objects.
    // TODO: cache outerClasses / innerClasses / accessedFields

    logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")

    // A list of classes that represents closures enclosed in the given one
    val innerClasses = getInnerClosureClasses(func)

    // A list of enclosing objects and their respective classes, from innermost to outermost
    // An outer object at a given index is of type outer class at the same index
    val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)

    // For logging purposes only
    val declaredFields = func.getClass.getDeclaredFields
    val declaredMethods = func.getClass.getDeclaredMethods

    logDebug(" + declared fields: " + declaredFields.size)
    declaredFields.foreach { f => logDebug("     " + f) }
    logDebug(" + declared methods: " + declaredMethods.size)
    declaredMethods.foreach { m => logDebug("     " + m) }
    logDebug(" + inner classes: " + innerClasses.size)
    innerClasses.foreach { c => logDebug("     " + c.getName) }
    logDebug(" + outer classes: " + outerClasses.size)
    outerClasses.foreach { c => logDebug("     " + c.getName) }
    logDebug(" + outer objects: " + outerObjects.size)
    outerObjects.foreach { o => logDebug("     " + o) }

    // Fail fast if we detect return statements in closures
    getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)

    // If accessed fields is not populated yet, we assume that
    // the closure we are trying to clean is the starting one
    if (accessedFields.isEmpty) {
      logDebug(s" + populating accessed fields because this is the starting closure")
      // Initialize accessed fields with the outer classes first
      // This step is needed to associate the fields to the correct classes later
      for (cls <- outerClasses) {
        accessedFields(cls) = Set[String]()
      }
      // Populate accessed fields by visiting all fields and methods accessed by this and
      // all of its inner closures. If transitive cleaning is enabled, this may recursively
      // visit methods that belong to other classes in search of transitively referenced fields.
      for (cls <- func.getClass :: innerClasses) {
        getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0)
      }
    }

    logDebug(s" + fields accessed by starting closure: " + accessedFields.size)
    accessedFields.foreach { f => logDebug("     " + f) }

    // List of outer (class, object) pairs, ordered from outermost to innermost
    // Note that all outer objects but the outermost one (first one in this list) must be closures
    var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
    var parent: AnyRef = null
    if (outerPairs.size > 0) {
      val (outermostClass, outermostObject) = outerPairs.head
      if (isClosure(outermostClass)) {
        logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}")
      } else if (outermostClass.getName.startsWith("$line")) {
        // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it
        // as it may carry a lot of unnecessary information, e.g. hadoop conf, spark conf, etc.
        logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}")
      } else {
        // The closure is ultimately nested inside a class; keep the object of that
        // class without cloning it since we don't want to clone the user's objects.
        // Note that we still need to keep around the outermost object itself because
        // we need it to clone its child closure later (see below).
        logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " +
          outerPairs.head)
        parent = outermostObject // e.g. SparkContext
        outerPairs = outerPairs.tail
      }
    } else {
      logDebug(" + there are no enclosing objects!")
    }

    // Clone the closure objects themselves, nulling out any fields that are not
    // used in the closure we're working on or any of its inner closures.
    for ((cls, obj) <- outerPairs) {
      logDebug(s" + cloning the object $obj of class ${cls.getName}")
      // We null out these unused references by cloning each object and then filling in all
      // required fields from the original object. We need the parent here because the Java
      // language specification requires the first constructor parameter of any closure to be
      // its enclosing object.
      val clone = instantiateClass(cls, parent)
      for (fieldName <- accessedFields(cls)) {
        val field = cls.getDeclaredField(fieldName)
        field.setAccessible(true)
        val value = field.get(obj)
        field.set(clone, value)
      }
      // If transitive cleaning is enabled, we recursively clean any enclosing closure using
      // the already populated accessed fields map of the starting closure
      if (cleanTransitively && isClosure(clone.getClass)) {
        logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})")
        // No need to check serializable here for the outer closures because we're
        // only interested in the serializability of the starting closure
        clean(clone, checkSerializable = false, cleanTransitively, accessedFields)
      }
      parent = clone
    }

    // Update the parent pointer ($outer) of this closure
    if (parent != null) {
      val field = func.getClass.getDeclaredField("$outer")
      field.setAccessible(true)
      // If the starting closure doesn't actually need our enclosing object, then just null it out
      if (accessedFields.contains(func.getClass) &&
        !accessedFields(func.getClass).contains("$outer")) {
        logDebug(s" + the starting closure doesn't actually need $parent, so we null it out")
        field.set(func, null)
      } else {
        // Update this closure's parent pointer to point to our enclosing object,
        // which could either be a cloned closure or the original user object
        field.set(func, parent)
      }
    }

    logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++")

    if (checkSerializable) {
      ensureSerializable(func)
    }
  }

  /**
   * Serialize func with the current SparkEnv's closure serializer, if a SparkEnv exists.
   *
   * @throws SparkException wrapping any exception raised during serialization
   */
  private def ensureSerializable(func: AnyRef): Unit = {
    try {
      if (SparkEnv.get != null) {
        SparkEnv.get.closureSerializer.newInstance().serialize(func)
      }
    } catch {
      case ex: Exception => throw new SparkException("Task not serializable", ex)
    }
  }

  /**
   * Create an instance of cls without running any of its constructors, and point its
   * `$outer` field at enclosingObject when that is non-null.
   */
  private def instantiateClass(
      cls: Class[_],
      enclosingObject: AnyRef): AnyRef = {
    // Use reflection to instantiate object without calling constructor
    val rf = sun.reflect.ReflectionFactory.getReflectionFactory()
    val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
    val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
    val obj = newCtor.newInstance().asInstanceOf[AnyRef]
    if (enclosingObject != null) {
      val field = cls.getDeclaredField("$outer")
      field.setAccessible(true)
      field.set(obj, enclosingObject)
    }
    obj
  }
}
318
/**
 * Raised when a closure's `apply` bytecode allocates a `scala.runtime.NonLocalReturnControl`,
 * i.e. the closure body contains a `return` statement.
 */
private[spark] class ReturnStatementInClosureException
    extends SparkException("Return statements aren't allowed in Spark closures")
321
/**
 * ASM visitor that fails with [[ReturnStatementInClosureException]] when any method whose
 * name contains "apply" allocates a `scala.runtime.NonLocalReturnControl`. That allocation
 * is how a `return` statement inside a closure shows up in the bytecode.
 */
private class ReturnStatementFinder extends ClassVisitor(ASM5) {
  override def visitMethod(
      access: Int,
      name: String,
      desc: String,
      sig: String,
      exceptions: Array[String]): MethodVisitor = {
    if (!name.contains("apply")) {
      // Not a closure body: inspect nothing
      new MethodVisitor(ASM5) {}
    } else {
      new MethodVisitor(ASM5) {
        override def visitTypeInsn(op: Int, tp: String): Unit = {
          val allocatesNonLocalReturn =
            op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")
          if (allocatesNonLocalReturn) {
            throw new ReturnStatementInClosureException
          }
        }
      }
    }
  }
}
338
/**
 * Identifies a method by its declaring class, its name and its JVM descriptor.
 * Instances are compared structurally, which lets them serve as set elements when
 * tracking already-visited methods.
 */
private case class MethodIdentifier[T](
    cls: Class[T],
    name: String,
    desc: String)
341
/**
 * ASM visitor that records which fields of a set of classes are read by the visited bytecode.
 *
 * Results are accumulated into the mutable `fields` map passed to the constructor. That map
 * must already contain a key for every class of interest, because only those classes are
 * tracked.
 *
 * @param fields map from each class of interest to the names of its fields found to be accessed
 * @param findTransitively if true, also follow method calls on the tracked classes to find
 *                         fields that are referenced indirectly
 * @param specificMethod if defined, restrict the visit to this one method (name and descriptor)
 * @param visitedMethods methods already followed; shared across recursive visits to break cycles
 */
private[util] class FieldAccessFinder(
    fields: Map[Class[_], Set[String]],
    findTransitively: Boolean,
    specificMethod: Option[MethodIdentifier[_]] = None,
    visitedMethods: Set[MethodIdentifier[_]] = Set.empty)
  extends ClassVisitor(ASM5) {

  override def visitMethod(
      access: Int,
      name: String,
      desc: String,
      sig: String,
      exceptions: Array[String]): MethodVisitor = {
    // Visit every method unless we were told to look at one specific method only
    val wanted = specificMethod.forall(target => target.name == name && target.desc == desc)
    if (wanted) {
      new MethodVisitor(ASM5) {
        override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = {
          if (op == GETFIELD) {
            val ownerClassName = owner.replace('/', '.')
            for (tracked <- fields.keys if tracked.getName == ownerClassName) {
              fields(tracked) += name
            }
          }
        }

        override def visitMethodInsn(
            op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
          val ownerClassName = owner.replace('/', '.')
          // A getter call on an interpreter wrapper object ($iwC) means the corresponding
          // field will be accessed, so it is recorded as such.
          val isReplGetter =
            op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")
          for (tracked <- fields.keys if tracked.getName == ownerClassName) {
            if (isReplGetter) {
              fields(tracked) += name
            }
            // Optionally follow the call to find fields that are referenced transitively.
            // `add` returns false for a method we've already followed, avoiding infinite cycles.
            if (findTransitively) {
              val callee = MethodIdentifier(tracked, name, desc)
              if (visitedMethods.add(callee)) {
                ClosureCleaner.getClassReader(tracked).accept(
                  new FieldAccessFinder(fields, findTransitively, Some(callee), visitedMethods), 0)
              }
            }
          }
        }
      }
    } else {
      null
    }
  }
}
405
/**
 * ASM visitor that adds to `output` every class whose constructor is invoked (via
 * INVOKESPECIAL `<init>`) with an instance of the visited class as its first argument.
 * This is how closures created inside the visited class receive their enclosing object.
 *
 * @param output mutable set that receives the classes found
 */
private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) {
  // Internal (slash-separated) name of the class being visited; set in `visit`
  var myName: String = null

  // TODO: Recursively find inner closures that we indirectly reference, e.g.
  //   val closure1 = () => { () => 1 }
  //   val closure2 = () => { (1 to 5).map(closure1) }
  // The second closure technically has two inner closures, but this finder only finds one

  override def visit(version: Int, access: Int, name: String, sig: String,
      superName: String, interfaces: Array[String]): Unit = {
    myName = name
  }

  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    new MethodVisitor(ASM5) {
      override def visitMethodInsn(
          op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
        // Check the cheap conditions first so the descriptor is only parsed for constructors
        if (op == INVOKESPECIAL && name == "<init>" && firstArgIsVisitedClass(desc)) {
          // scalastyle:off classforname
          output += Class.forName(
              owner.replace('/', '.'),
              false,
              Thread.currentThread.getContextClassLoader)
          // scalastyle:on classforname
        }
      }
    }
  }

  // Whether the method descriptor's first parameter is an object of the visited class's type
  private def firstArgIsVisitedClass(desc: String): Boolean = {
    val argTypes = Type.getArgumentTypes(desc)
    argTypes.nonEmpty &&
      argTypes(0).getSort == Type.OBJECT &&
      argTypes(0).getInternalName == myName
  }
}
439