1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnet
19
20import org.apache.mxnet.util.NativeLibraryLoader
21import org.slf4j.{Logger, LoggerFactory}
22
23import scala.Specializable.Group
24
/**
 * Shared type aliases, helpers, and native-library bootstrap for the MXNet JVM binding.
 *
 * Initializing this object has side effects that run in the order written below:
 * the `mxnet-scala` native library is loaded, `_LIB.nativeLibInit()` is called through
 * [[checkCall]], and a JVM shutdown hook that calls `mxNotifyShutdown()` is registered.
 * The order matters (`logger` must exist before the loading block, `_LIB` before the
 * first `checkCall`), so do not reorder these initializers.
 */
private[mxnet] object Base {
  private val logger: Logger = LoggerFactory.getLogger("MXNetJVM")

  // type definitions
  // Boxed holders for primitive/String values used as out-parameters of native calls.
  // NOTE(review): presumably filled in by the JNI side (LibInfo) — confirm.
  class RefInt(val value: Int = 0)
  class RefLong(val value: Long = 0)
  class RefFloat(val value: Float = 0)
  class RefString(val value: String = null)

  // Scala-side aliases for MXNet scalar types (MXUint is carried as a signed Int).
  type MXUint = Int
  type MXFloat = Float
  // Native pointers are carried as raw Long addresses.
  type CPtrAddress = Long

  // Handles to native MXNet objects, all represented as raw addresses.
  type NDArrayHandle = CPtrAddress
  type FunctionHandle = CPtrAddress
  type DataIterHandle = CPtrAddress
  type DataIterCreator = CPtrAddress
  type KVStoreHandle = CPtrAddress
  type ExecutorHandle = CPtrAddress
  type SymbolHandle = CPtrAddress
  type RecordIOHandle = CPtrAddress
  type RtcHandle = CPtrAddress

  // Out-parameter holders for the scalar and handle types above.
  type MXUintRef = RefInt
  type MXFloatRef = RefFloat
  type NDArrayHandleRef = RefLong
  type FunctionHandleRef = RefLong
  type DataIterHandleRef = RefLong
  type DataIterCreatorRef = RefLong
  type KVStoreHandleRef = RefLong
  type ExecutorHandleRef = RefLong
  type SymbolHandleRef = RefLong
  type RecordIOHandleRef = RefLong
  type RtcHandleRef = RefLong

  // Real-number dtype of the binding; Float32 matches MXFloat = Float above.
  val MX_REAL_TYPE = DType.Float32

  // The primitives currently supported for NDArray operations, expressed as a
  // scala.Specializable.Group (presumably consumed by @specialized elsewhere).
  val MX_PRIMITIVES = new Group ((Double, Float))


  /* Find the native library on the library path or, failing that, copy it
   * from the dependency jar into a temp directory and load it from there.
   */
  try {
    try {
      tryLoadLibraryOS("mxnet-scala")
    } catch {
      case osError: UnsatisfiedLinkError =>
        // Keep the reason the native-path lookup failed instead of discarding it.
        logger.debug(s"Loading mxnet-scala from the native path failed: ${osError.getMessage}")
        logger.info("Copying and loading native library from the jar archive")
        try {
          NativeLibraryLoader.loadLibrary("mxnet-scala")
        } catch {
          case jarError: UnsatisfiedLinkError =>
            // Attach the first failure so the rethrown error explains both attempts.
            jarError.addSuppressed(osError)
            throw jarError
        }
    }
  } catch {
    case e: UnsatisfiedLinkError =>
      // Pass the throwable so its stack trace (and suppressed cause) is logged too.
      logger.error("Couldn't find native library mxnet-scala", e)
      throw e
  }

  // Native entry points; must be initialized before the first checkCall below.
  val _LIB = new LibInfo
  checkCall(_LIB.nativeLibInit())

  // TODO: shutdown hook won't work on Windows
  // On JVM exit, tell the native library it is shutting down.
  Runtime.getRuntime.addShutdownHook(new Thread() {
    override def run(): Unit = {
      notifyShutdown()
    }
  })

  /**
   * Load `libname` from the JVM's native library path via `System.loadLibrary`.
   *
   * @param libname library name without platform prefix/suffix
   * @throws UnsatisfiedLinkError if the library cannot be found or loaded
   */
  @throws(classOf[UnsatisfiedLinkError])
  private def tryLoadLibraryOS(libname: String): Unit = {
    logger.info(s"Try loading $libname from native path.")
    System.loadLibrary(libname)
  }

  // helper function definitions
  /**
   * Check the return value of a C API call.
   *
   * Wrap every native API call with this function; any non-zero return value
   * is treated as a failure.
   *
   * @param ret return value from the API call
   * @throws MXNetError carrying `_LIB.mxGetLastError()` when `ret != 0`
   */
  def checkCall(ret: Int): Unit = {
    if (ret != 0) {
      throw new MXNetError(_LIB.mxGetLastError())
    }
  }

  // Notify MXNet about a shutdown; invoked from the shutdown hook registered above.
  private def notifyShutdown(): Unit = {
    checkCall(_LIB.mxNotifyShutdown())
  }

  /**
   * Convert argument information returned by native calls into a parameter docstring.
   *
   * Each parameter is rendered as `name : type`, followed by its description on the
   * next line when the description is non-empty. The three sequences are zipped, so
   * entries beyond the length of the shortest sequence are silently dropped.
   *
   * @param argNames parameter names
   * @param argTypes parameter type descriptions, aligned with `argNames`
   * @param argDescs parameter descriptions, aligned with `argNames`
   * @return a "Parameters" section: a header, a dashed underline, then one entry per parameter
   */
  def ctypes2docstring(
      argNames: Seq[String],
      argTypes: Seq[String],
      argDescs: Seq[String]): String = {

    val params =
      (argNames zip argTypes zip argDescs) map { case ((argName, argType), argDesc) =>
        val desc = if (argDesc.isEmpty) "" else s"\n$argDesc"
        s"$argName : $argType$desc"
      }
    s"Parameters\n----------\n${params.mkString("\n")}\n"
  }
}
133
/**
 * Error reported by the MXNet native library (thrown by `Base.checkCall` when a
 * C API call returns a non-zero status).
 *
 * @param err the error message; it is also this exception's `getMessage`
 */
class MXNetError(val err: String)
  extends java.lang.Exception(err)
135
// Type classes used by the Symbol.random and NDArray.random modules to accept either
// a scalar or a Symbol/NDArray argument
137
138class SymbolOrScalar[T](val isScalar: Boolean)
139object SymbolOrScalar {
140  def apply[T](implicit ev: SymbolOrScalar[T]): SymbolOrScalar[T] = ev
141  implicit object FloatWitness extends SymbolOrScalar[Float](true)
142  implicit object IntWitness extends SymbolOrScalar[Int](true)
143  implicit object SymbolWitness extends SymbolOrScalar[Symbol](false)
144}
145
/**
 * Type-class witness telling whether `T` is a plain scalar or an `NDArray`.
 *
 * @param isScalar `true` for the scalar witnesses (`Float`, `Int`), `false` for `NDArray`
 */
class NDArrayOrScalar[T](val isScalar: Boolean)

/** Implicit witnesses for the types accepted where an NDArray or a scalar is allowed. */
object NDArrayOrScalar {
  /** Summon the implicit witness for `T`. */
  def apply[T](implicit ev: NDArrayOrScalar[T]): NDArrayOrScalar[T] = ev

  // scalar witnesses
  implicit object FloatWitness extends NDArrayOrScalar[Float](isScalar = true)
  implicit object IntWitness extends NDArrayOrScalar[Int](isScalar = true)

  // array witness
  implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](isScalar = false)
}
153