1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnet
19
20import java.util.HashSet
21
22import org.slf4j.LoggerFactory
23
24import scala.collection.mutable
25import scala.collection.mutable.ArrayBuffer
26import scala.util.Try
27import scala.util.control.{ControlThrowable, NonFatal}
28
29/**
30  * This class manages automatically releasing of `org.apache.mxnet.NativeResource`s
31  */
32class ResourceScope extends AutoCloseable {
33
34  // HashSet does not take a custom comparator
35  private[mxnet] val resourceQ = new mutable.TreeSet[NativeResource]()(nativeAddressOrdering)
36
37  private object nativeAddressOrdering extends Ordering[NativeResource] {
38    def compare(a: NativeResource, b: NativeResource): Int = {
39      a.nativeAddress compare  b.nativeAddress
40    }
41  }
42
43  ResourceScope.addToThreadLocal(this)
44
45  /**
46    * Releases all the `org.apache.mxnet.NativeResource` by calling
47    * the associated`'org.apache.mxnet.NativeResource.close()` method
48    */
49  override def close(): Unit = {
50    ResourceScope.removeFromThreadLocal(this)
51    if (!ResourceScope.threadLocalScopes.get().contains(this)) {
52      resourceQ.foreach(resource => if (resource != null) resource.dispose(false))
53      resourceQ.clear()
54    }
55  }
56
57  /**
58    * Add a NativeResource to the scope
59    * @param resource
60    */
61  def add(resource: NativeResource): Unit = {
62    resourceQ.+=(resource)
63    resource.scope = Some(this)
64  }
65
66  /**
67    * Check if a NativeResource is in the scope
68    * @param resource
69    */
70  def contains(resource: NativeResource): Boolean = {
71    resourceQ.contains(resource)
72  }
73
74  /**
75    * Remove NativeResource from the Scope, this uses
76    * object equality to find the resource in the stack.
77    * @param resource
78    */
79  def remove(resource: NativeResource): Unit = {
80    resourceQ.-=(resource)
81    resource.scope = None
82  }
83
84  /**
85    * Removes from current Scope and moves to outer scope if it exists
86    * @param resource Resource to be moved to an outer scope
87    */
88  def moveToOuterScope(resource: NativeResource): Unit = {
89    val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
90    if (prevScope.isDefined) {
91      if (contains(resource)) {
92        this.remove(resource)
93        prevScope.get.add(resource)
94      }
95    } else this.remove(resource)
96  }
97
98}
99
object ResourceScope {

  private val logger = LoggerFactory.getLogger(classOf[ResourceScope])

  /**
    * Runs `body` inside a ResourceScope and, once the body finishes (normally or by
    * throwing), closes that scope, which disposes the NativeResources it tracks.
    * Before closing, NativeResources reachable from the value returned by `body`
    * (directly, inside an NDArrayFuncReturn, or nested in a Traversable / Product such as
    * a tuple) are moved to the previous scope on this thread's scope stack, or detached
    * when there is none, so they are not de-allocated.
    * @param scope (Optional). Scope in which to capture the native resources; when null,
    *              a new ResourceScope is created. The scope is closed at the end.
    * @param body  block of code to execute in this scope
    * @tparam A return type
    * @return result of the operation; NativeResources it contains are not de-allocated, so
    *         the user can use them and then de-allocate them manually by calling close or
    *         by enclosing the call in another resourceScope.
    * @throws Throwable whatever `body` threw, with any exception from closing the scope
    *         attached as suppressed; if closing throws a fatal error while the body's
    *         failure was non-fatal, the fatal error is thrown instead and carries the
    *         body's failure as a suppressed exception.
    */
  // inspired from slide 21 of https://www.slideshare.net/Odersky/fosdem-2009-1013261
  // and https://github.com/scala/scala/blob/2.13.x/src/library/scala/util/Using.scala
  // TODO: we should move to the Scala util's Using method when we move to Scala 2.13
  def using[A](scope: ResourceScope = null)(body: => A): A = {

    val curScope = if (scope != null) scope else new ResourceScope()

    // Walks the returned value and moves every NativeResource found out of curScope.
    def recursiveMoveToOuterScope(resource: Any): Unit = {
      resource match {
        case nRes: NativeResource => curScope.moveToOuterScope(nRes)
        case ndRet: NDArrayFuncReturn => ndRet.arr.foreach(nd => curScope.moveToOuterScope(nd))
        case resInGeneric: scala.collection.Traversable[_] =>
          resInGeneric.foreach(recursiveMoveToOuterScope)
        case resProduct: scala.Product =>
          resProduct.productIterator.foreach(recursiveMoveToOuterScope)
        case _ => // nothing that can hold a NativeResource
      }
    }

    // ControlThrowables are not real errors; self-suppression is rejected by the JDK.
    @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
      if (!t.isInstanceOf[ControlThrowable] && (t ne suppressed)) t.addSuppressed(suppressed)
    }

    var retThrowable: Throwable = null

    try {
      val ret = body
      recursiveMoveToOuterScope(ret)
      ret
    } catch {
      case t: Throwable =>
        retThrowable = t
        null.asInstanceOf[A] // we'll throw in finally
    } finally {
      var toThrow: Throwable = retThrowable
      if (retThrowable eq null) curScope.close
      else {
        try {
          curScope.close
        } catch {
          case closeThrowable: Throwable =>
            if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) {
              // the fatal close error wins, but keep the body's failure attached to it
              safeAddSuppressed(closeThrowable, retThrowable)
              toThrow = closeThrowable
            } else safeAddSuppressed(retThrowable, closeThrowable)
        } finally {
          throw toThrow
        }
      }
    }
  }

  /**
    * Runs `body` inside the given scope when one is provided, otherwise runs it directly.
    * The scope is pushed onto the thread's scope stack again first, so the `close` done
    * by `using` only pops that extra entry and does not dispose the scope's resources
    * while the original registration is still on the stack.
    * @param scope optional scope to run `body` in
    * @param body  block of code to execute
    * @tparam A return type
    * @return result of `body`
    */
  private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = {
    scope match {
      case Some(s) =>
        ResourceScope.addToThreadLocal(s)
        ResourceScope.using(s)(body)
      case None => body
    }
  }

  // thread local Scopes: a per-thread stack, innermost scope last
  private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
    override def initialValue(): ArrayBuffer[ResourceScope] =
      new ArrayBuffer[ResourceScope]()
  }

  /**
    * Pushes a scope onto the current thread's scope stack.
    * @param r ResourceScope to add.
    */
  private[mxnet] def addToThreadLocal(r: ResourceScope): Unit = {
    threadLocalScopes.get() += r
  }

  /**
    * Removes the most recent occurrence of a scope from the current thread's scope stack.
    * Does nothing if the scope is not on the stack, so closing a scope more than once
    * (or closing an already-closed scope via `using`) does not throw.
    * @param r ResourceScope to remove
    */
  private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
    val scopes = threadLocalScopes.get()
    val idx = scopes.lastIndexOf(r)
    if (idx >= 0) scopes.remove(idx)
  }

  /**
    * Get the latest Scope in the current thread's stack.
    * @return the innermost scope, or None when the stack is empty
    */
  private[mxnet] def getCurrentScope(): Option[ResourceScope] = {
    threadLocalScopes.get().lastOption
  }

  /**
    * Get the last-but-one Scope from the current thread's stack.
    * @return n-1th scope or None when the stack has fewer than two entries
    */
  private[mxnet] def getPrevScope(): Option[ResourceScope] = {
    val scopes = threadLocalScopes.get()
    scopes.lift(scopes.size - 2)
  }
}
215