/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet

import java.util.HashSet

import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
import scala.util.control.{ControlThrowable, NonFatal}

/**
 * This class manages automatically releasing of `org.apache.mxnet.NativeResource`s.
 *
 * Constructing a scope pushes it onto the calling thread's scope stack
 * (`ResourceScope.threadLocalScopes`); `close()` pops it again.
 */
class ResourceScope extends AutoCloseable {

  // Tracked resources, keyed by native address. A TreeSet is used because
  // HashSet does not take a custom comparator.
  private[mxnet] val resourceQ = new mutable.TreeSet[NativeResource]()(nativeAddressOrdering)

  private object nativeAddressOrdering extends Ordering[NativeResource] {
    def compare(a: NativeResource, b: NativeResource): Int = {
      a.nativeAddress compare b.nativeAddress
    }
  }

  ResourceScope.addToThreadLocal(this)

  /**
   * Pops this scope from the current thread's scope stack and, unless the same scope
   * is still present further down that stack (i.e. it was re-entered), releases every
   * tracked `org.apache.mxnet.NativeResource` by calling its `dispose(false)` method.
   *
   * Calling `close()` on a scope that is no longer on the current thread's stack
   * (for example a second `close()`) does not throw.
   */
  override def close(): Unit = {
    ResourceScope.removeFromThreadLocal(this)
    if (!ResourceScope.threadLocalScopes.get().contains(this)) {
      // NOTE(review): `false` presumably tells dispose not to call back into
      // this scope's remove(), since the whole queue is cleared right below -
      // confirm against NativeResource.dispose.
      resourceQ.foreach(resource => if (resource != null) resource.dispose(false))
      resourceQ.clear()
    }
  }

  /**
   * Add a NativeResource to this scope and record this scope as its owner.
   * @param resource resource to track; it is released when this scope is closed
   */
  def add(resource: NativeResource): Unit = {
    resourceQ += resource
    resource.scope = Some(this)
  }

  /**
   * Check if a NativeResource is tracked by this scope.
   * @param resource resource to look up; matched by native address
   * @return true if a resource with the same native address is tracked here
   */
  def contains(resource: NativeResource): Boolean = {
    resourceQ.contains(resource)
  }

  /**
   * Remove a NativeResource from this scope and clear its owner. The resource is
   * located by its native address (the ordering of `resourceQ`), not by object
   * equality. The resource is not disposed.
   * @param resource resource to stop tracking
   */
  def remove(resource: NativeResource): Unit = {
    resourceQ -= resource
    resource.scope = None
  }

  /**
   * Removes a resource from this scope and moves it to the outer scope if one exists.
   * When there is no outer scope the resource is only detached, so this scope will
   * no longer release it. Resources that this scope does not track are left untouched.
   * @param resource Resource to be moved to an outer scope
   */
  def moveToOuterScope(resource: NativeResource): Unit = {
    // Only act on resources we own: detaching a foreign resource would wrongly
    // clear the `scope` field that points at its real owner.
    if (contains(resource)) {
      // The "outer" scope is the second-to-last entry of this thread's stack,
      // which assumes this scope is the innermost (last) one.
      ResourceScope.getPrevScope() match {
        case Some(outer) =>
          this.remove(resource)
          outer.add(resource)
        case None =>
          this.remove(resource)
      }
    }
  }

}

object ResourceScope {

  private val logger = LoggerFactory.getLogger(classOf[ResourceScope])

  /**
   * Captures all Native Resources created using the ResourceScope and, at the end
   * of the body, deallocates all of them by closing the scope.
   * This method will not deallocate NativeResources returned from the block:
   * they are moved to the outer scope, or detached when there is none.
   * @param scope (Optional). Scope in which to capture the native resources;
   *              when null a new scope is created
   * @param body block of code to execute in this scope
   * @tparam A return type
   * @return result of the operation. If the result is (or contains) a NativeResource,
   *         it is not deallocated, so the user can use it and then deallocate it
   *         manually by calling close, or enclose it in another resourceScope.
   */
  // inspired from slide 21 of https://www.slideshare.net/Odersky/fosdem-2009-1013261
  // and https://github.com/scala/scala/blob/2.13.x/src/library/scala/util/Using.scala
  // TODO: we should move to the Scala util's Using method when we move to Scala 2.13
  def using[A](scope: ResourceScope = null)(body: => A): A = {

    val curScope = Option(scope).getOrElse(new ResourceScope())

    // Walks the returned value and rescues every NativeResource found in it
    // (directly, in NDArrayFuncReturn, collections or tuples/case classes).
    def recursiveMoveToOuterScope(resource: Any): Unit = {
      resource match {
        case nRes: NativeResource => curScope.moveToOuterScope(nRes)
        case ndRet: NDArrayFuncReturn => ndRet.arr.foreach(nd => curScope.moveToOuterScope(nd))
        case resInGeneric: scala.collection.Traversable[_] =>
          resInGeneric.foreach(recursiveMoveToOuterScope)
        case resProduct: scala.Product =>
          resProduct.productIterator.foreach(recursiveMoveToOuterScope)
        case _ => // do nothing
      }
    }

    // ControlThrowables must stay lightweight, so never attach suppressed exceptions to them.
    @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
      if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
    }

    var retThrowable: Throwable = null

    try {
      val ret = body
      recursiveMoveToOuterScope(ret)
      ret
    } catch {
      case t: Throwable =>
        retThrowable = t
        null.asInstanceOf[A] // we'll throw in finally
    } finally {
      var toThrow: Throwable = retThrowable
      if (retThrowable eq null) curScope.close()
      else {
        try {
          curScope.close()
        } catch {
          case closeThrowable: Throwable =>
            // A fatal error from close() wins over a non-fatal body failure;
            // otherwise keep the body's failure and attach the close failure to it.
            if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = closeThrowable
            else safeAddSuppressed(retThrowable, closeThrowable)
        } finally {
          throw toThrow
        }
      }
    }
  }

  /**
   * Runs `body` inside the given scope if there is one, otherwise runs it directly.
   * @param scope scope to re-enter, or None to run `body` without any scope handling
   * @param body block of code to execute
   * @tparam A return type
   * @return the result of `body`
   */
  private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = {
    scope match {
      case None => body
      case Some(s) =>
        // Push the scope so getCurrentScope() returns it while `body` runs. `using`
        // closes it afterwards, which pops this extra entry and releases the scope's
        // resources only if the scope is no longer on the stack below.
        ResourceScope.addToThreadLocal(s)
        ResourceScope.using(s)(body)
    }
  }

  // Per-thread stack of open scopes; the innermost scope is the last element.
  private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
    override def initialValue(): ArrayBuffer[ResourceScope] =
      new ArrayBuffer[ResourceScope]()
  }

  /**
   * Push a scope onto the current thread's scope stack.
   * @param r ResourceScope to add.
   */
  private[mxnet] def addToThreadLocal(r: ResourceScope): Unit = {
    threadLocalScopes.get() += r
  }

  /**
   * Remove the most recent occurrence of a scope from the current thread's stack.
   * Does nothing when the scope is not on this thread's stack (e.g. it was already
   * closed, or it was pushed on another thread).
   * @param r ResourceScope to remove
   */
  private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
    val scopes = threadLocalScopes.get()
    val idx = scopes.lastIndexOf(r)
    // lastIndexOf yields -1 when absent, and ArrayBuffer.remove(-1) would throw.
    if (idx >= 0) scopes.remove(idx)
  }

  /**
   * Get the latest (innermost) Scope in the current thread's stack.
   * @return the innermost scope, or None when the stack is empty
   */
  private[mxnet] def getCurrentScope(): Option[ResourceScope] = {
    threadLocalScopes.get().lastOption
  }

  /**
   * Get the last-but-one Scope from the current thread's stack.
   * @return n-1th scope or None when not found
   */
  private[mxnet] def getPrevScope(): Option[ResourceScope] = {
    val scopes = threadLocalScopes.get()
    scopes.lift(scopes.size - 2)
  }
}