/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet

import org.apache.mxnet.util.NativeLibraryLoader
import org.slf4j.{Logger, LoggerFactory}

import scala.Specializable.Group

/**
 * Package-wide bootstrap and shared definitions for the MXNet JVM binding.
 *
 * Initializing this object has side effects, executed in this order:
 *  1. loads the `mxnet-scala` native library (native path first, then the jar fallback),
 *  2. creates [[LibInfo]] as `_LIB` and calls `nativeLibInit()` on it through [[checkCall]],
 *  3. registers a JVM shutdown hook that calls `mxNotifyShutdown()`.
 *
 * It also defines the handle type aliases and helpers used across the package.
 */
private[mxnet] object Base {
  private val logger: Logger = LoggerFactory.getLogger("MXNetJVM")

  // Name of the native library, shared by both load attempts and the error message.
  // A `final val` with a literal is a compile-time constant, so it is usable
  // regardless of where the loading code below runs during object initialization.
  private final val NativeLibName = "mxnet-scala"

  // type definitions
  // Boxed holders backing the `*Ref` aliases below.
  // NOTE(review): these appear to be out-parameters filled in by the native side;
  // keep their shape (a single `value` field) unchanged unless that is confirmed otherwise.
  class RefInt(val value: Int = 0)
  class RefLong(val value: Long = 0)
  class RefFloat(val value: Float = 0)
  class RefString(val value: String = null)

  type MXUint = Int
  type MXFloat = Float
  // C pointer addresses are carried on the JVM side as plain Longs.
  type CPtrAddress = Long

  type NDArrayHandle = CPtrAddress
  type FunctionHandle = CPtrAddress
  type DataIterHandle = CPtrAddress
  type DataIterCreator = CPtrAddress
  type KVStoreHandle = CPtrAddress
  type ExecutorHandle = CPtrAddress
  type SymbolHandle = CPtrAddress
  type RecordIOHandle = CPtrAddress
  type RtcHandle = CPtrAddress

  type MXUintRef = RefInt
  type MXFloatRef = RefFloat
  type NDArrayHandleRef = RefLong
  type FunctionHandleRef = RefLong
  type DataIterHandleRef = RefLong
  type DataIterCreatorRef = RefLong
  type KVStoreHandleRef = RefLong
  type ExecutorHandleRef = RefLong
  type SymbolHandleRef = RefLong
  type RecordIOHandleRef = RefLong
  type RtcHandleRef = RefLong

  val MX_REAL_TYPE = DType.Float32

  // The primitives currently supported for NDArray operations, expressed as a
  // `scala.Specializable.Group` (the form accepted by `@specialized(...)`).
  val MX_PRIMITIVES = new Group ((Double, Float))

  /*
   * Load the native library. First try the JVM's native library search path
   * (System.loadLibrary). If that fails, fall back to NativeLibraryLoader, which
   * copies the library out of the dependency jar into a temporary directory
   * and loads it from there. If both attempts fail, log an error and rethrow
   * the fallback's UnsatisfiedLinkError.
   */
  try {
    try {
      tryLoadLibraryOS(NativeLibName)
    } catch {
      case e: UnsatisfiedLinkError =>
        // Record why the native-path load failed. Without this, the reason is
        // lost whenever the fallback below also fails, because only the second
        // error propagates.
        logger.debug(s"Loading $NativeLibName from native path failed: ${e.getMessage}")
        logger.info("Copying and loading native library from the jar archive")
        NativeLibraryLoader.loadLibrary(NativeLibName)
    }
  } catch {
    case e: UnsatisfiedLinkError =>
      logger.error(s"Couldn't find native library $NativeLibName")
      throw e
  }

  // Wrapper over the native entry points. It must be created only after the
  // library above has been loaded.
  val _LIB = new LibInfo
  checkCall(_LIB.nativeLibInit())

  // TODO: shutdown hook won't work on Windows
  Runtime.getRuntime.addShutdownHook(new Thread() {
    override def run(): Unit = {
      notifyShutdown()
    }
  })

  /**
   * Load `libname` from the JVM's native library search path.
   *
   * @param libname library name as accepted by `System.loadLibrary`
   * @throws UnsatisfiedLinkError if the library cannot be found or loaded
   */
  @throws(classOf[UnsatisfiedLinkError])
  private def tryLoadLibraryOS(libname: String): Unit = {
    logger.info(s"Try loading $libname from native path.")
    System.loadLibrary(libname)
  }

  // helper function definitions
  /**
   * Check the return value of a C API call.
   *
   * Wrap every native API call with this function.
   *
   * @param ret return code from the API call; 0 means success
   * @throws MXNetError carrying the message from `mxGetLastError()` when `ret` is non-zero
   */
  def checkCall(ret: Int): Unit = {
    if (ret != 0) {
      throw new MXNetError(_LIB.mxGetLastError())
    }
  }

  // Notify MXNet about a shutdown (called from the JVM shutdown hook registered above).
  private def notifyShutdown(): Unit = {
    checkCall(_LIB.mxNotifyShutdown())
  }

  /**
   * Format argument metadata returned by the C API as a parameters docstring.
   *
   * The output is a "Parameters" header, a dashed underline, and one
   * `name : type` line per argument. When an argument's description is
   * non-empty, it follows on its own line.
   *
   * The three sequences are combined with `zip`, so if their lengths differ the
   * result silently covers only the shortest one.
   *
   * @param argNames argument names
   * @param argTypes argument type descriptions, parallel to `argNames`
   * @param argDescs argument descriptions, parallel to `argNames`
   * @return the formatted parameters section, ending with a newline
   */
  def ctypes2docstring(
      argNames: Seq[String],
      argTypes: Seq[String],
      argDescs: Seq[String]): String = {

    val params =
      argNames.zip(argTypes).zip(argDescs).map { case ((argName, argType), argDesc) =>
        val desc = if (argDesc.isEmpty) "" else s"\n$argDesc"
        s"$argName : $argType$desc"
      }
    s"Parameters\n----------\n${params.mkString("\n")}\n"
  }
}

/** Exception raised when an MXNet native call reports failure; `err` is the native error message. */
class MXNetError(val err: String) extends Exception(err)

// Some type classes to ease the work in the Symbol.random and NDArray.random modules.

/**
 * Type-class witness for arguments that may be either a scalar or a [[Symbol]].
 *
 * @param isScalar true for the scalar instances (Float, Int), false for Symbol
 */
class SymbolOrScalar[T](val isScalar: Boolean)
object SymbolOrScalar {
  /** Summon the implicit witness for `T`. */
  def apply[T](implicit ev: SymbolOrScalar[T]): SymbolOrScalar[T] = ev
  implicit object FloatWitness extends SymbolOrScalar[Float](true)
  implicit object IntWitness extends SymbolOrScalar[Int](true)
  implicit object SymbolWitness extends SymbolOrScalar[Symbol](false)
}

/**
 * Type-class witness for arguments that may be either a scalar or an [[NDArray]].
 *
 * @param isScalar true for the scalar instances (Float, Int), false for NDArray
 */
class NDArrayOrScalar[T](val isScalar: Boolean)
object NDArrayOrScalar {
  /** Summon the implicit witness for `T`. */
  def apply[T](implicit ev: NDArrayOrScalar[T]): NDArrayOrScalar[T] = ev
  implicit object FloatWitness extends NDArrayOrScalar[Float](true)
  implicit object IntWitness extends NDArrayOrScalar[Int](true)
  implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](false)
}