/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.{Map, Set, Stack}
import scala.language.existentials

import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.apache.xbean.asm5.Opcodes._

import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.internal.Logging

/**
 * A cleaner that renders closures serializable if they can be done so safely.
 */
private[spark] object ClosureCleaner extends Logging {

  /**
   * Get an ASM class reader for a given class from the JAR that loaded it.
   *
   * The class file is looked up as a resource relative to `cls`, using the simple
   * (package-less) class name with a `.class` suffix, and its bytes are copied into
   * memory before being handed to ASM.
   *
   * @param cls the class whose bytecode should be read
   * @return a reader over the bytecode of `cls`
   */
  private[util] def getClassReader(cls: Class[_]): ClassReader = {
    // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
    val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
    val resourceStream = cls.getResourceAsStream(className)
    if (resourceStream == null) {
      // TODO: Fixme - continuing with earlier behavior: the null stream is handed to
      // ClassReader as-is (NOTE(review): presumably ASM rejects it with an exception - confirm).
      new ClassReader(resourceStream)
    } else {
      val baos = new ByteArrayOutputStream(128)
      Utils.copyStream(resourceStream, baos, true)
      new ClassReader(new ByteArrayInputStream(baos.toByteArray))
    }
  }

  /** Check whether a class represents a Scala closure, judged by its name containing "anonfun". */
  private def isClosure(cls: Class[_]): Boolean = {
    cls.getName.contains("$anonfun$")
  }

  // Get a list of the outer objects and their classes of a given closure object, obj;
  // the outer objects are defined as any closures that obj is nested within, plus
  // possibly the class that the outermost closure is in, if any. We stop searching
  // for outer objects beyond that because cloning the user's object is probably
  // not a good idea (whereas we can clone closure objects just fine since we
  // understand how all their fields are used).
  private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = {
    // Lazily walk the declared $outer fields and stop at the first one holding a non-null
    // value. The outer pointer may be null if we have cleaned this closure before.
    val firstLiveOuter = obj.getClass.getDeclaredFields.iterator
      .filter(_.getName == "$outer")
      .map { f =>
        f.setAccessible(true)
        (f, f.get(obj))
      }
      .find { case (_, outer) => outer != null }

    firstLiveOuter match {
      case Some((f, outer)) if isClosure(f.getType) =>
        // The enclosing object is itself a closure, so keep climbing
        val (classes, objects) = getOuterClassesAndObjects(outer)
        (f.getType :: classes, outer :: objects)
      case Some((f, outer)) =>
        // Stop at the first $outer that is not a closure
        (f.getType :: Nil, outer :: Nil)
      case None =>
        (Nil, Nil)
    }
  }

  /**
   * Return a list of classes that represent closures enclosed in the given closure object.
   */
  private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = {
    val seen = Set[Class[_]](obj.getClass)
    val stack = Stack[Class[_]](obj.getClass)
    while (stack.nonEmpty) {
      val cr = getClassReader(stack.pop())
      val set = Set[Class[_]]()
      cr.accept(new InnerClosureFinder(set), 0)
      // Only classes not encountered before are queued for their own scan
      for (cls <- set -- seen) {
        seen += cls
        stack.push(cls)
      }
    }
    (seen - obj.getClass).toList
  }

  /**
   * Clean the given closure in place.
   *
   * More specifically, this renders the given closure serializable as long as it does not
   * explicitly reference unserializable objects.
   *
   * @param closure the closure to clean; a null closure is ignored
   * @param checkSerializable whether to verify that the closure is serializable after cleaning
   * @param cleanTransitively whether to clean enclosing closures transitively
   */
  def clean(
      closure: AnyRef,
      checkSerializable: Boolean = true,
      cleanTransitively: Boolean = true): Unit = {
    clean(closure, checkSerializable, cleanTransitively, Map.empty)
  }

  /**
   * Helper method to clean the given closure in place.
   *
   * The mechanism is to traverse the hierarchy of enclosing closures and null out any
   * references along the way that are not actually used by the starting closure, but are
   * nevertheless included in the compiled anonymous classes. Note that it is unsafe to
   * simply mutate the enclosing closures in place, as other code paths may depend on them.
   * Instead, we clone each enclosing closure and set the parent pointers accordingly.
   *
   * By default, closures are cleaned transitively. This means we detect whether enclosing
   * objects are actually referenced by the starting one, either directly or transitively,
   * and, if not, sever these closures from the hierarchy. In other words, in addition to
   * nulling out unused field references, we also null out any parent pointers that refer
   * to enclosing objects not actually needed by the starting closure. We determine
   * transitivity by tracing through the tree of all methods ultimately invoked by the
   * inner closure and record all the fields referenced in the process.
   *
   * For instance, transitive cleaning is necessary in the following scenario:
   *
   *   class SomethingNotSerializable {
   *     def someValue = 1
   *     def scope(name: String)(body: => Unit) = body
   *     def someMethod(): Unit = scope("one") {
   *       def x = someValue
   *       def y = 2
   *       scope("two") { println(y + 1) }
   *     }
   *   }
   *
   * In this example, scope "two" is not serializable because it references scope "one", which
   * references SomethingNotSerializable. Note that, however, the body of scope "two" does not
   * actually depend on SomethingNotSerializable. This means we can safely null out the parent
   * pointer of a cloned scope "one" and set it as the parent of scope "two", such that scope
   * "two" no longer references SomethingNotSerializable transitively.
   *
   * A null `func` is ignored, and an object whose class is not recognized as a closure only
   * triggers a warning; in both cases nothing is modified.
   *
   * @param func the starting closure to clean
   * @param checkSerializable whether to verify that the closure is serializable after cleaning
   * @param cleanTransitively whether to clean enclosing closures transitively
   * @param accessedFields a map from a class to a set of its fields that are accessed by
   *                       the starting closure
   */
  private def clean(
      func: AnyRef,
      checkSerializable: Boolean,
      cleanTransitively: Boolean,
      accessedFields: Map[Class[_], Set[String]]): Unit = {

    // The null check must come before anything dereferences func (e.g. func.getClass)
    if (func == null) {
      return
    }

    if (!isClosure(func.getClass)) {
      logWarning("Expected a closure; got " + func.getClass.getName)
      return
    }

    // TODO: clean all inner closures first. This requires us to find the inner objects.
    // TODO: cache outerClasses / innerClasses / accessedFields

    logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")

    // A list of classes that represents closures enclosed in the given one
    val innerClasses = getInnerClosureClasses(func)

    // A list of enclosing objects and their respective classes, from innermost to outermost
    // An outer object at a given index is of type outer class at the same index
    val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)

    // For logging purposes only
    val declaredFields = func.getClass.getDeclaredFields
    val declaredMethods = func.getClass.getDeclaredMethods

    logDebug(" + declared fields: " + declaredFields.size)
    declaredFields.foreach { f => logDebug(" " + f) }
    logDebug(" + declared methods: " + declaredMethods.size)
    declaredMethods.foreach { m => logDebug(" " + m) }
    logDebug(" + inner classes: " + innerClasses.size)
    innerClasses.foreach { c => logDebug(" " + c.getName) }
    logDebug(" + outer classes: " + outerClasses.size)
    outerClasses.foreach { c => logDebug(" " + c.getName) }
    logDebug(" + outer objects: " + outerObjects.size)
    outerObjects.foreach { o => logDebug(" " + o) }

    // Fail fast if we detect return statements in closures
    getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)

    // If accessed fields is not populated yet, we assume that
    // the closure we are trying to clean is the starting one
    if (accessedFields.isEmpty) {
      logDebug(s" + populating accessed fields because this is the starting closure")
      // Initialize accessed fields with the outer classes first
      // This step is needed to associate the fields to the correct classes later
      for (cls <- outerClasses) {
        accessedFields(cls) = Set[String]()
      }
      // Populate accessed fields by visiting all fields and methods accessed by this and
      // all of its inner closures. If transitive cleaning is enabled, this may recursively
      // visits methods that belong to other classes in search of transitively referenced fields.
      for (cls <- func.getClass :: innerClasses) {
        getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0)
      }
    }

    logDebug(s" + fields accessed by starting closure: " + accessedFields.size)
    accessedFields.foreach { f => logDebug(" " + f) }

    // List of outer (class, object) pairs, ordered from outermost to innermost
    // Note that all outer objects but the outermost one (first one in this list) must be closures
    var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
    var parent: AnyRef = null
    if (outerPairs.nonEmpty) {
      val (outermostClass, outermostObject) = outerPairs.head
      if (isClosure(outermostClass)) {
        logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}")
      } else if (outermostClass.getName.startsWith("$line")) {
        // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it
        // as it may carry a lot of unnecessary information, e.g. hadoop conf, spark conf, etc.
        logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}")
      } else {
        // The closure is ultimately nested inside a class; keep the object of that
        // class without cloning it since we don't want to clone the user's objects.
        // Note that we still need to keep around the outermost object itself because
        // we need it to clone its child closure later (see below).
        logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " +
          outerPairs.head)
        parent = outermostObject // e.g. SparkContext
        outerPairs = outerPairs.tail
      }
    } else {
      logDebug(" + there are no enclosing objects!")
    }

    // Clone the closure objects themselves, nulling out any fields that are not
    // used in the closure we're working on or any of its inner closures.
    for ((cls, obj) <- outerPairs) {
      logDebug(s" + cloning the object $obj of class ${cls.getName}")
      // We null out these unused references by cloning each object and then filling in all
      // required fields from the original object. We need the parent here because the Java
      // language specification requires the first constructor parameter of any closure to be
      // its enclosing object.
      val clone = instantiateClass(cls, parent)
      for (fieldName <- accessedFields(cls)) {
        val field = cls.getDeclaredField(fieldName)
        field.setAccessible(true)
        val value = field.get(obj)
        field.set(clone, value)
      }
      // If transitive cleaning is enabled, we recursively clean any enclosing closure using
      // the already populated accessed fields map of the starting closure
      if (cleanTransitively && isClosure(clone.getClass)) {
        logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})")
        // No need to check serializable here for the outer closures because we're
        // only interested in the serializability of the starting closure
        clean(clone, checkSerializable = false, cleanTransitively, accessedFields)
      }
      parent = clone
    }

    // Update the parent pointer ($outer) of this closure
    if (parent != null) {
      val field = func.getClass.getDeclaredField("$outer")
      field.setAccessible(true)
      // If the starting closure doesn't actually need our enclosing object, then just null it out
      if (accessedFields.contains(func.getClass) &&
        !accessedFields(func.getClass).contains("$outer")) {
        logDebug(s" + the starting closure doesn't actually need $parent, so we null it out")
        field.set(func, null)
      } else {
        // Update this closure's parent pointer to point to our enclosing object,
        // which could either be a cloned closure or the original user object
        field.set(func, parent)
      }
    }

    logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++")

    if (checkSerializable) {
      ensureSerializable(func)
    }
  }

  /**
   * Try to serialize the given closure with the closure serializer of the current SparkEnv.
   * Nothing is attempted when no SparkEnv is set.
   *
   * @param func the closure to serialize
   * @throws SparkException wrapping any exception raised while serializing
   */
  private def ensureSerializable(func: AnyRef): Unit = {
    try {
      // Read SparkEnv.get once; it is null when no environment has been set
      Option(SparkEnv.get).foreach { env =>
        env.closureSerializer.newInstance().serialize(func)
      }
    } catch {
      case ex: Exception => throw new SparkException("Task not serializable", ex)
    }
  }

  // Create an instance of cls without running any of its constructors and, when an enclosing
  // object is given, store it in the new instance's $outer field.
  private def instantiateClass(
      cls: Class[_],
      enclosingObject: AnyRef): AnyRef = {
    // Use reflection to instantiate object without calling constructor
    val rf = sun.reflect.ReflectionFactory.getReflectionFactory()
    val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
    val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
    val obj = newCtor.newInstance().asInstanceOf[AnyRef]
    if (enclosingObject != null) {
      val field = cls.getDeclaredField("$outer")
      field.setAccessible(true)
      field.set(obj, enclosingObject)
    }
    obj
  }
}

private[spark] class ReturnStatementInClosureException
  extends SparkException("Return statements aren't allowed in Spark closures")

/**
 * Class visitor that throws a [[ReturnStatementInClosureException]] when a method whose name
 * contains "apply" instantiates `scala/runtime/NonLocalReturnControl`.
 */
private class ReturnStatementFinder extends ClassVisitor(ASM5) {
  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    if (name.contains("apply")) {
      new MethodVisitor(ASM5) {
        override def visitTypeInsn(op: Int, tp: String): Unit = {
          if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
            throw new ReturnStatementInClosureException
          }
        }
      }
    } else {
      // Other methods are visited with a visitor that does nothing
      new MethodVisitor(ASM5) {}
    }
  }
}

/** Helper class to identify a method. */
private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String)

/**
 * Find the fields accessed by a given class.
 *
 * The resulting fields are stored in the mutable map passed in through the constructor.
 * This map is assumed to have its keys already populated with the classes of interest.
 *
 * @param fields the mutable map that stores the fields to return
 * @param findTransitively if true, find fields indirectly referenced through method calls
 * @param specificMethod if not empty, visit only this specific method
 * @param visitedMethods a set of visited methods to avoid cycles
 */
private[util] class FieldAccessFinder(
    fields: Map[Class[_], Set[String]],
    findTransitively: Boolean,
    specificMethod: Option[MethodIdentifier[_]] = None,
    visitedMethods: Set[MethodIdentifier[_]] = Set.empty)
  extends ClassVisitor(ASM5) {

  override def visitMethod(
      access: Int,
      name: String,
      desc: String,
      sig: String,
      exceptions: Array[String]): MethodVisitor = {

    // If we are told to visit only a certain method and this is not the one, ignore it
    val isOtherMethod = specificMethod.exists(m => m.name != name || m.desc != desc)
    if (isOtherMethod) {
      null
    } else {
      new MethodVisitor(ASM5) {
        override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = {
          if (op == GETFIELD) {
            // Convert the internal (slash-separated) owner name once, not once per key
            val ownerName = owner.replace('/', '.')
            for (cl <- fields.keys if cl.getName == ownerName) {
              fields(cl) += name
            }
          }
        }

        override def visitMethodInsn(
            op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
          val ownerName = owner.replace('/', '.')
          for (cl <- fields.keys if cl.getName == ownerName) {
            // Check for calls a getter method for a variable in an interpreter wrapper object.
            // This means that the corresponding field will be accessed, so we should save it.
            if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
              fields(cl) += name
            }
            // Optionally visit other methods to find fields that are transitively referenced
            if (findTransitively) {
              val m = MethodIdentifier(cl, name, desc)
              if (!visitedMethods.contains(m)) {
                // Keep track of visited methods to avoid potential infinite cycles
                visitedMethods += m
                ClosureCleaner.getClassReader(cl).accept(
                  new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
              }
            }
          }
        }
      }
    }
  }
}

/**
 * Class visitor that collects, into `output`, the owner class of every constructor call
 * (INVOKESPECIAL of `<init>`) whose first argument is an object of the class being visited.
 *
 * @param output the mutable set that receives the classes found
 */
private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) {
  // Internal name of the class being visited; assigned in visit()
  var myName: String = null

  // TODO: Recursively find inner closures that we indirectly reference, e.g.
  //   val closure1 = () => { () => 1 }
  //   val closure2 = () => { (1 to 5).map(closure1) }
  // The second closure technically has two inner closures, but this finder only finds one

  override def visit(version: Int, access: Int, name: String, sig: String,
      superName: String, interfaces: Array[String]): Unit = {
    myName = name
  }

  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    new MethodVisitor(ASM5) {
      override def visitMethodInsn(
          op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
        val argTypes = Type.getArgumentTypes(desc)
        if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
            && argTypes(0).toString.startsWith("L") // is it an object?
            && argTypes(0).getInternalName == myName) {
          // scalastyle:off classforname
          output += Class.forName(
            owner.replace('/', '.'),
            false,
            Thread.currentThread.getContextClassLoader)
          // scalastyle:on classforname
        }
      }
    }
  }
}