1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file org_apache_mxnet_native_c_api.cc
22  * \brief JNI function implementations
23  */
24 #include "org_apache_mxnet_native_c_api.h"  // generated by javah
25 #include <nnvm/c_api.h>
26 #include <mxnet/c_api.h>
27 #include <dmlc/logging.h>
28 #include <mxnet/ndarray.h>
29 #include <../src/common/cuda_utils.h>
30 #include <mutex>
31 #include <iostream>
32 #include <functional>
33 #include <string>
34 #include <unordered_map>
35 #include <vector>
36 #include "jni_helper_func.h"
37 
38 JavaVM *_jvm;
39 
Java_org_apache_mxnet_LibInfo_nativeLibInit(JNIEnv * env,jobject obj)40 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_nativeLibInit
41   (JNIEnv *env, jobject obj) {
42   return env->GetJavaVM(&_jvm);
43 }
44 
Java_org_apache_mxnet_LibInfo_mxListAllOpNames(JNIEnv * env,jobject obj,jobject nameList)45 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListAllOpNames
46   (JNIEnv *env, jobject obj, jobject nameList) {
47   mx_uint outSize;
48   const char **outArray;
49   int ret = MXListAllOpNames(&outSize, &outArray);
50 
51   jclass listCls = env->FindClass("scala/collection/mutable/ListBuffer");
52   jmethodID listAppend = env->GetMethodID(listCls,
53     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
54   for (size_t i = 0; i < outSize; ++i) {
55     env->CallObjectMethod(nameList, listAppend, env->NewStringUTF(outArray[i]));
56   }
57 
58   return ret;
59 }
60 
Java_org_apache_mxnet_LibInfo_nnGetOpHandle(JNIEnv * env,jobject obj,jstring jopname,jobject jhandle)61 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_nnGetOpHandle
62   (JNIEnv *env, jobject obj, jstring jopname, jobject jhandle) {
63   OpHandle handle;
64   const char *opname = env->GetStringUTFChars(jopname, 0);
65   int ret = NNGetOpHandle(opname, &handle);
66   env->ReleaseStringUTFChars(jopname, opname);
67 
68   jclass refClass = env->FindClass("org/apache/mxnet/Base$RefLong");
69   jfieldID refFid = env->GetFieldID(refClass, "value", "J");
70   env->SetLongField(jhandle, refFid, reinterpret_cast<jlong>(handle));
71 
72   return ret;
73 }
74 
Java_org_apache_mxnet_LibInfo_mxNDArrayCreateNone(JNIEnv * env,jobject obj,jobject ndArrayHandle)75 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateNone
76   (JNIEnv *env, jobject obj, jobject ndArrayHandle) {
77   NDArrayHandle out;
78   int ret = MXNDArrayCreateNone(&out);
79   SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
80   return ret;
81 }
82 
Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx(JNIEnv * env,jobject obj,jintArray shape,jint ndim,jint devType,jint devId,jint delayAlloc,jint dtype,jobject ndArrayHandle)83 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx
84   (JNIEnv *env, jobject obj, jintArray shape, jint ndim, jint devType,
85     jint devId, jint delayAlloc, jint dtype, jobject ndArrayHandle) {
86   jint *shapeArr = env->GetIntArrayElements(shape, NULL);
87   NDArrayHandle out;
88   int ret = MXNDArrayCreateEx(reinterpret_cast<mx_uint *>(shapeArr), static_cast<mx_uint>(ndim),
89                               devType, devId, delayAlloc, dtype, &out);
90   env->ReleaseIntArrayElements(shape, shapeArr, 0);
91   SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
92   return ret;
93 }
94 
Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx(JNIEnv * env,jobject obj,jint storageType,jintArray shape,jint ndim,jint devType,jint devId,jint delayAlloc,jint dtype,jint numAux,jintArray auxTypes,jintArray auxNdims,jintArray auxShapes,jobject ndArrayHandle)95 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx
96   (JNIEnv *env, jobject obj, jint storageType, jintArray shape, jint ndim, jint devType,
97     jint devId, jint delayAlloc, jint dtype, jint numAux, jintArray auxTypes,
98     jintArray auxNdims, jintArray auxShapes, jobject ndArrayHandle) {
99     jint *shapeArr = env->GetIntArrayElements(shape, NULL);
100     jint *auxTypesArr = env->GetIntArrayElements(auxTypes, NULL);
101     jint *auxNdimsArr = env->GetIntArrayElements(auxNdims, NULL);
102     jint *auxShapesArr = env->GetIntArrayElements(auxShapes, NULL);
103     NDArrayHandle out;
104     int ret = MXNDArrayCreateSparseEx(storageType,
105      reinterpret_cast<const mx_uint *>(shapeArr),
106      static_cast<mx_uint>(ndim),
107      devType, devId, delayAlloc, dtype,
108      static_cast<mx_uint>(numAux),
109      reinterpret_cast<int *>(auxTypesArr),
110      reinterpret_cast<mx_uint *>(auxNdimsArr),
111      reinterpret_cast<const mx_uint *>(auxShapesArr),  &out);
112     env->ReleaseIntArrayElements(shape, shapeArr, 0);
113     env->ReleaseIntArrayElements(auxTypes, auxTypesArr, 0);
114     env->ReleaseIntArrayElements(auxNdims, auxNdimsArr, 0);
115     env->ReleaseIntArrayElements(auxShapes, auxShapesArr, 0);
116     SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
117     return ret;
118 }
119 
Java_org_apache_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv * env,jobject obj)120 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv *env, jobject obj) {
121   return MXNDArrayWaitAll();
122 }
123 
Java_org_apache_mxnet_LibInfo_mxNDArrayWaitToRead(JNIEnv * env,jobject obj,jlong arrayPtr)124 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitToRead
125   (JNIEnv *env, jobject obj, jlong arrayPtr) {
126   return MXNDArrayWaitToRead(reinterpret_cast<NDArrayHandle>(arrayPtr));
127 }
128 
Java_org_apache_mxnet_LibInfo_mxListFunctions(JNIEnv * env,jobject obj,jobject functions)129 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListFunctions
130   (JNIEnv *env, jobject obj, jobject functions) {
131   jclass longCls = env->FindClass("java/lang/Long");
132   jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
133 
134   // scala.collection.mutable.ListBuffer append method
135   jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
136   jmethodID listAppend = env->GetMethodID(listClass,
137     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
138 
139   // Get function list
140   FunctionHandle *outArray;
141   mx_uint outSize;
142   int ret = MXListFunctions(&outSize, &outArray);
143   for (size_t i = 0; i < outSize; ++i) {
144     env->CallObjectMethod(functions, listAppend,
145                           env->NewObject(longCls, longConst, outArray[i]));
146   }
147   return ret;
148 }
149 
Java_org_apache_mxnet_LibInfo_mxFuncDescribe(JNIEnv * env,jobject obj,jlong funcPtr,jobject nUsedVars,jobject nScalars,jobject nMutateVars,jobject typeMask)150 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncDescribe
151   (JNIEnv *env, jobject obj, jlong funcPtr, jobject nUsedVars,
152     jobject nScalars, jobject nMutateVars, jobject typeMask) {
153   mx_uint numUseVars;
154   mx_uint numScalars;
155   mx_uint numMutateVars;
156   int type;
157   int ret = MXFuncDescribe(reinterpret_cast<FunctionHandle>(funcPtr), &numUseVars,
158                             &numScalars, &numMutateVars, &type);
159 
160   jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
161   jfieldID value = env->GetFieldID(refIntClass, "value", "I");
162   env->SetIntField(nUsedVars, value, static_cast<jint>(numUseVars));
163   env->SetIntField(nScalars, value, static_cast<jint>(numScalars));
164   env->SetIntField(nMutateVars, value, static_cast<jint>(numMutateVars));
165   env->SetIntField(typeMask, value, static_cast<jint>(type));
166 
167   return ret;
168 }
169 
Java_org_apache_mxnet_LibInfo_mxFuncGetInfo(JNIEnv * env,jobject obj,jlong funcPtr,jobject name,jobject desc,jobject numArgs,jobject argNames,jobject argTypes,jobject argDescs)170 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncGetInfo
171   (JNIEnv *env, jobject obj, jlong funcPtr, jobject name, jobject desc,
172     jobject numArgs, jobject argNames, jobject argTypes, jobject argDescs) {
173   const char *cName;
174   const char *cDesc;
175   mx_uint cNumArgs;
176   const char **cArgNames;
177   const char **cArgTypes;
178   const char **cArgDescs;
179   int ret = MXFuncGetInfo(reinterpret_cast<FunctionHandle>(funcPtr),
180                           &cName, &cDesc, &cNumArgs,
181                           &cArgNames, &cArgTypes, &cArgDescs);
182 
183   jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
184   jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
185 
186   jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
187   jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
188 
189   // scala.collection.mutable.ListBuffer append method
190   jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
191   jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
192       "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
193 
194   env->SetObjectField(name, valueStr, env->NewStringUTF(cName));
195   env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc));
196   env->SetIntField(numArgs, valueInt, static_cast<jint>(cNumArgs));
197   for (size_t i = 0; i < cNumArgs; ++i) {
198     env->CallObjectMethod(argNames, listAppend, env->NewStringUTF(cArgNames[i]));
199     env->CallObjectMethod(argTypes, listAppend, env->NewStringUTF(cArgTypes[i]));
200     env->CallObjectMethod(argDescs, listAppend, env->NewStringUTF(cArgDescs[i]));
201   }
202 
203   return ret;
204 }
205 
Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx(JNIEnv * env,jobject obj,jlong funcPtr,jlongArray inputs,jlongArray outputsGiven,jobject outputs,jint numParams,jobjectArray paramKeys,jobjectArray paramVals,jobject outStypes)206 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx
207   (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray inputs,
208     jlongArray outputsGiven, jobject outputs, jint numParams,
209     jobjectArray paramKeys, jobjectArray paramVals, jobject outStypes) {
210 
211   const char **cParamKeys = NULL;
212   const char **cParamVals = NULL;
213   if (numParams > 0) {
214     cParamKeys = new const char *[numParams];
215     cParamVals = new const char *[numParams];
216     for (int i = 0; i < numParams; i++) {
217       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
218       const char *key = env->GetStringUTFChars(jkey, 0);
219       cParamKeys[i] = key;
220       env->DeleteLocalRef(jkey);
221       jstring jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
222       const char *val = env->GetStringUTFChars(jval, 0);
223       cParamVals[i] = val;
224       env->DeleteLocalRef(jval);
225     }
226   }
227 
228   int numOutputs = 0;
229   jlong *cOutputsGiven = NULL;
230   NDArrayHandle *cOutputs = NULL;
231   const int *cOutStypes;
232   if (outputsGiven) {
233     cOutputsGiven = env->GetLongArrayElements(outputsGiven, NULL);
234     cOutputs = reinterpret_cast<NDArrayHandle *>(cOutputsGiven);
235     numOutputs = static_cast<int>(env->GetArrayLength(outputsGiven));
236   }
237   jlong *cInputs = env->GetLongArrayElements(inputs, NULL);
238   jsize numInputs = env->GetArrayLength(inputs);
239   int ret = MXImperativeInvokeEx(reinterpret_cast<AtomicSymbolCreator>(funcPtr),
240                                static_cast<int>(numInputs),
241                                reinterpret_cast<NDArrayHandle *>(cInputs),
242                                &numOutputs,
243                                &cOutputs,
244                                static_cast<int>(numParams),
245                                cParamKeys,
246                                cParamVals,
247                                &cOutStypes);
248   env->ReleaseLongArrayElements(inputs, cInputs, 0);
249 
250   // release allocated memory
251   if (numParams > 0) {
252     for (int i = 0; i < numParams; i++) {
253       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
254       env->ReleaseStringUTFChars(jkey, cParamKeys[i]);
255       env->DeleteLocalRef(jkey);
256       jstring jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
257       env->ReleaseStringUTFChars(jval, cParamVals[i]);
258       env->DeleteLocalRef(jval);
259     }
260     delete[] cParamKeys;
261     delete[] cParamVals;
262   }
263 
264   if (cOutputs) {
265     jclass longCls = env->FindClass("java/lang/Long");
266     jclass intCls = env->FindClass("java/lang/Integer");
267     jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
268     jmethodID intConst = env->GetMethodID(intCls, "<init>", "(I)V");
269     // scala.collection.mutable.ListBuffer append method
270     jclass listClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
271     jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
272         "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
273     for (int i = 0; i < numOutputs; ++i) {
274       env->CallObjectMethod(outputs, listAppend,
275                             env->NewObject(longCls, longConst,
276                             reinterpret_cast<uint64_t>(cOutputs[i])));
277       env->CallObjectMethod(outStypes, listAppend,
278                             env->NewObject(intCls, intConst,
279                             cOutStypes[i]));
280     }
281   }
282 
283   if (cOutputsGiven) {
284     env->ReleaseLongArrayElements(outputsGiven, cOutputsGiven, 0);
285   }
286 
287   return ret;
288 }
289 
Java_org_apache_mxnet_LibInfo_mxFuncInvoke(JNIEnv * env,jobject obj,jlong funcPtr,jlongArray useVars,jfloatArray scalarArgs,jlongArray mutateVars)290 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncInvoke
291   (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray useVars,
292     jfloatArray scalarArgs, jlongArray mutateVars) {
293   jlong *cUseVars = env->GetLongArrayElements(useVars, NULL);
294   jfloat *cScalarArgs = env->GetFloatArrayElements(scalarArgs, NULL);
295   jlong *cMutateVars = env->GetLongArrayElements(mutateVars, NULL);
296   int ret = MXFuncInvoke(reinterpret_cast<FunctionHandle>(funcPtr),
297                          reinterpret_cast<NDArrayHandle *>(cUseVars),
298                          reinterpret_cast<mx_float *>(cScalarArgs),
299                          reinterpret_cast<NDArrayHandle *>(cMutateVars));
300   env->ReleaseLongArrayElements(useVars, cUseVars, 0);
301   env->ReleaseFloatArrayElements(scalarArgs, cScalarArgs, 0);
302   env->ReleaseLongArrayElements(mutateVars, cMutateVars, 0);
303   return ret;
304 }
305 
Java_org_apache_mxnet_LibInfo_mxFuncInvokeEx(JNIEnv * env,jobject obj,jlong funcPtr,jlongArray useVars,jfloatArray scalarArgs,jlongArray mutateVars,jint numParams,jobjectArray paramKeys,jobjectArray paramVals)306 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncInvokeEx
307   (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray useVars,
308     jfloatArray scalarArgs, jlongArray mutateVars,
309     jint numParams, jobjectArray paramKeys, jobjectArray paramVals) {
310   jlong *cUseVars = env->GetLongArrayElements(useVars, NULL);
311   jfloat *cScalarArgs = env->GetFloatArrayElements(scalarArgs, NULL);
312   jlong *cMutateVars = env->GetLongArrayElements(mutateVars, NULL);
313   jbyte **cParamKeys = NULL;
314   jbyte **cParamVals = NULL;
315   if (numParams > 0) {
316     cParamKeys = new jbyte *[numParams];
317     cParamVals = new jbyte *[numParams];
318     for (int i = 0; i < numParams; i++) {
319       jbyteArray jkey = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramKeys, i));
320       jbyte *cParamKey = env->GetByteArrayElements(jkey, NULL);
321       cParamKeys[i] = cParamKey;
322       env->DeleteLocalRef(jkey);
323       jbyteArray jval = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramVals, i));
324       jbyte *cParamVal = env->GetByteArrayElements(jval, NULL);
325       cParamVals[i] = cParamVal;
326       env->DeleteLocalRef(jval);
327     }
328   }
329   int ret = MXFuncInvokeEx(reinterpret_cast<FunctionHandle>(funcPtr),
330                            reinterpret_cast<NDArrayHandle *>(cUseVars),
331                            reinterpret_cast<mx_float *>(cScalarArgs),
332                            reinterpret_cast<NDArrayHandle *>(cMutateVars),
333                            static_cast<int>(numParams),
334                            reinterpret_cast<char **>(cParamKeys),
335                            reinterpret_cast<char **>(cParamVals));
336   env->ReleaseLongArrayElements(useVars, cUseVars, 0);
337   env->ReleaseFloatArrayElements(scalarArgs, cScalarArgs, 0);
338   env->ReleaseLongArrayElements(mutateVars, cMutateVars, 0);
339   if (numParams > 0) {
340     for (int i = 0; i < numParams; i++) {
341       jbyteArray jkey = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramKeys, i));
342       env->ReleaseByteArrayElements(jkey, cParamKeys[i], 0);
343       env->DeleteLocalRef(jkey);
344       jbyteArray jval = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramVals, i));
345       env->ReleaseByteArrayElements(jval, cParamVals[i], 0);
346       env->DeleteLocalRef(jval);
347     }
348     delete[] cParamKeys;
349     delete[] cParamVals;
350   }
351   return ret;
352 }
353 
Java_org_apache_mxnet_LibInfo_mxNDArraySaveRawBytes(JNIEnv * env,jobject obj,jlong ndArrayPtr,jobject dataBuf)354 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySaveRawBytes
355   (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject dataBuf) {
356   size_t length;
357   const char *pdata;
358   int ret = MXNDArraySaveRawBytes(reinterpret_cast<NDArrayHandle>(ndArrayPtr), &length, &pdata);
359 
360   // fill dataBuf
361   jclass byteClass = env->FindClass("java/lang/Byte");
362   jmethodID newByte = env->GetMethodID(byteClass, "<init>", "(B)V");
363   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
364   jmethodID arrayAppend = env->GetMethodID(arrayClass,
365     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
366   for (size_t i = 0; i < length; ++i) {
367     jobject data = env->NewObject(byteClass, newByte, static_cast<jbyte>(pdata[i]));
368     env->CallObjectMethod(dataBuf, arrayAppend, data);
369     env->DeleteLocalRef(data);
370   }
371 
372   return ret;
373 }
374 
Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes(JNIEnv * env,jobject obj,jbyteArray bytes,jobject handleRef)375 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes
376   (JNIEnv *env, jobject obj, jbyteArray bytes, jobject handleRef) {
377   int size = env->GetArrayLength(bytes);
378   jbyte *byteArr = env->GetByteArrayElements(bytes, NULL);
379   NDArrayHandle out;
380   int ret = MXNDArrayLoadFromRawBytes(reinterpret_cast<const void *>(byteArr),
381                                       static_cast<size_t>(size), &out);
382   env->ReleaseByteArrayElements(bytes, byteArr, 0);
383   SetLongField(env, handleRef, reinterpret_cast<jlong>(out));
384   return ret;
385 }
386 
Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape(JNIEnv * env,jobject obj,jlong ndArrayPtr,jobject ndimRef,jobject dataBuf)387 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
388   (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject ndimRef, jobject dataBuf) {
389   int ndim;
390   const int *pdata;
391   int ret = MXNDArrayGetShapeEx(reinterpret_cast<NDArrayHandle>(ndArrayPtr), &ndim, &pdata);
392 
393   // fill dataBuf
394   jclass integerClass = env->FindClass("java/lang/Integer");
395   jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");
396 
397   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
398   jmethodID arrayAppend = env->GetMethodID(arrayClass,
399     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
400   for (int i = 0; i < ndim; ++i) {
401     jobject data = env->NewObject(integerClass, newInteger, pdata[i]);
402     env->CallObjectMethod(dataBuf, arrayAppend, data);
403     env->DeleteLocalRef(data);
404   }
405 
406   // set ndimRef
407   jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
408   jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
409   env->SetIntField(ndimRef, valueInt, ndim);
410 
411   return ret;
412 }
413 
Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray(JNIEnv * env,jobject obj,jlong dstPtr,jlong srcPtr,jint locator)414 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray
415   (JNIEnv *env, jobject obj, jlong dstPtr, jlong srcPtr, jint locator) {
416   int ret = MXNDArraySyncCopyFromNDArray(reinterpret_cast<NDArrayHandle>(dstPtr),
417                                    reinterpret_cast<NDArrayHandle>(srcPtr),
418                                    locator);
419   return ret;
420 }
421 
Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyToCPU(JNIEnv * env,jobject obj,jlong ndArrayPtr,jbyteArray data,jint size)422 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyToCPU
423   (JNIEnv *env, jobject obj, jlong ndArrayPtr, jbyteArray data, jint size) {
424   jbyte *pdata = env->GetByteArrayElements(data, NULL);
425   int ret = MXNDArraySyncCopyToCPU(reinterpret_cast<NDArrayHandle>(ndArrayPtr),
426                                    reinterpret_cast<void *>(pdata), size);
427   env->ReleaseByteArrayElements(data, pdata, 0);  // copy back to java array automatically
428   return ret;
429 }
430 
Java_org_apache_mxnet_LibInfo_mxNDArraySlice(JNIEnv * env,jobject obj,jlong ndArrayPtr,jint start,jint end,jobject slicedHandle)431 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySlice
432   (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint start, jint end, jobject slicedHandle) {
433   NDArrayHandle out;
434   int ret = MXNDArraySlice(reinterpret_cast<NDArrayHandle>(ndArrayPtr), start, end, &out);
435   SetLongField(env, slicedHandle, reinterpret_cast<jlong>(out));
436   return ret;
437 }
438 
Java_org_apache_mxnet_LibInfo_mxNDArrayAt(JNIEnv * env,jobject obj,jlong ndArrayPtr,jint idx,jobject jout)439 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
440   (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint idx, jobject jout) {
441   NDArrayHandle out;
442   int ret = MXNDArrayAt(reinterpret_cast<NDArrayHandle>(ndArrayPtr), idx, &out);
443   SetLongField(env, jout, reinterpret_cast<jlong>(out));
444   return ret;
445 }
446 
Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64(JNIEnv * env,jobject obj,jlong ndArrayPtr,jint ndim,jlongArray dims,jboolean reverse,jobject reshapedHandle)447 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
448   (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim,
449    jlongArray dims, jboolean reverse, jobject reshapedHandle) {
450   NDArrayHandle out;
451   jlong *pdims = env->GetLongArrayElements(dims, NULL);
452   int ret = MXNDArrayReshape64(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
453                                     reinterpret_cast<dim_t *>(pdims), reverse, &out);
454   SetLongField(env, reshapedHandle, reinterpret_cast<jlong>(out));
455   env->ReleaseLongArrayElements(dims, pdims, 0);
456   return ret;
457 }
458 
Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU(JNIEnv * env,jobject obj,jlong arrayPtr,jfloatArray sourceArr,jint arrSize)459 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
460   (JNIEnv *env, jobject obj, jlong arrayPtr, jfloatArray sourceArr, jint arrSize) {
461   jfloat *sourcePtr = env->GetFloatArrayElements(sourceArr, NULL);
462   int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast<NDArrayHandle>(arrayPtr),
463                                      static_cast<const mx_float *>(sourcePtr), arrSize);
464   env->ReleaseFloatArrayElements(sourceArr, sourcePtr, 0);
465   return ret;
466 }
467 
Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU(JNIEnv * env,jobject obj,jlong arrayPtr,jdoubleArray sourceArr,jint arrSize)468 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
469   (JNIEnv *env, jobject obj, jlong arrayPtr, jdoubleArray sourceArr, jint arrSize) {
470   jdouble *sourcePtr = env->GetDoubleArrayElements(sourceArr, NULL);
471   int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast<NDArrayHandle>(arrayPtr),
472                                      static_cast<const double *>(sourcePtr), arrSize);
473   env->ReleaseDoubleArrayElements(sourceArr, sourcePtr, 0);
474   return ret;
475 }
476 
Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray(JNIEnv * env,jobject obj,jlong arrayPtr,jobject ndArrayHandle)477 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray
478   (JNIEnv *env, jobject obj, jlong arrayPtr, jobject ndArrayHandle) {
479   NDArrayHandle out;
480   int ret = MXNDArrayGetDataNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
481                                      &out);
482   SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
483   return ret;
484 }
485 
Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray(JNIEnv * env,jobject obj,jlong arrayPtr,jint location,jobject ndArrayHandle)486 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray
487   (JNIEnv *env, jobject obj, jlong arrayPtr, jint location, jobject ndArrayHandle) {
488   NDArrayHandle out;
489   int ret = MXNDArrayGetAuxNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
490                                    static_cast<mx_uint>(location),
491                                    &out);
492   SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
493   return ret;
494 }
495 
Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext(JNIEnv * env,jobject obj,jlong arrayPtr,jobject devTypeId,jobject devId)496 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext
497   (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) {
498   int outDevType;
499   int outDevId;
500   int ret = MXNDArrayGetContext(reinterpret_cast<NDArrayHandle>(arrayPtr), &outDevType, &outDevId);
501   jclass refClass = env->FindClass("org/apache/mxnet/Base$RefInt");
502   jfieldID refFid = env->GetFieldID(refClass, "value", "I");
503   env->SetIntField(devTypeId, refFid, outDevType);
504   env->SetIntField(devId, refFid, outDevId);
505   return ret;
506 }
507 
Java_org_apache_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env,jobject obj,jlong ndArrayHandle)508 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayFree
509   (JNIEnv * env, jobject obj, jlong ndArrayHandle) {
510   return MXNDArrayFree(reinterpret_cast<NDArrayHandle>(ndArrayHandle));
511 }
512 
Java_org_apache_mxnet_LibInfo_mxNDArrayLoad(JNIEnv * env,jobject obj,jstring jfname,jobject joutSize,jobject jhandles,jobject joutNameSize,jobject jnames)513 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoad
514   (JNIEnv * env, jobject obj, jstring jfname, jobject joutSize,
515     jobject jhandles, jobject joutNameSize, jobject jnames) {
516   mx_uint outSize;
517   NDArrayHandle *outArr;
518   mx_uint outNameSize;
519   const char **outNames;
520 
521   const char *fname = env->GetStringUTFChars(jfname, 0);
522   int ret = MXNDArrayLoad(fname, &outSize, &outArr, &outNameSize, &outNames);
523   env->ReleaseStringUTFChars(jfname, fname);
524 
525   if (ret) {
526     return ret;
527   }
528 
529   // fill sizes
530   jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
531   jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
532   env->SetIntField(joutSize, valueInt, outSize);
533   env->SetIntField(joutNameSize, valueInt, outNameSize);
534 
535   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
536   jmethodID arrayAppend = env->GetMethodID(arrayClass,
537     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
538 
539   // fill handles
540   jclass longCls = env->FindClass("java/lang/Long");
541   jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
542   for (size_t i = 0; i < outSize; ++i) {
543     jobject handle = env->NewObject(longCls, longConst, outArr[i]);
544     env->CallObjectMethod(jhandles, arrayAppend, handle);
545     env->DeleteLocalRef(handle);
546   }
547 
548   // fill names
549   for (size_t i = 0; i < outNameSize; ++i) {
550     jstring jname = env->NewStringUTF(outNames[i]);
551     env->CallObjectMethod(jnames, arrayAppend, jname);
552     env->DeleteLocalRef(jname);
553   }
554 
555   return ret;
556 }
557 
Java_org_apache_mxnet_LibInfo_mxNDArraySave(JNIEnv * env,jobject obj,jstring jfname,jlongArray jhandles,jobjectArray jkeys)558 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySave
559   (JNIEnv * env, jobject obj, jstring jfname, jlongArray jhandles, jobjectArray jkeys) {
560   int numArgs = env->GetArrayLength(jhandles);
561   const char **keys = NULL;
562   if (jkeys != NULL) {
563     keys = new const char *[numArgs];
564     for (int i = 0; i < numArgs; i++) {
565       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
566       const char *key = env->GetStringUTFChars(jkey, 0);
567       keys[i] = key;
568       env->DeleteLocalRef(jkey);
569     }
570   }
571 
572   const char *fname = env->GetStringUTFChars(jfname, 0);
573   jlong *handles = env->GetLongArrayElements(jhandles, NULL);
574 
575   int ret = MXNDArraySave(fname, static_cast<mx_uint>(numArgs),
576                           reinterpret_cast<NDArrayHandle *>(handles), keys);
577 
578   env->ReleaseLongArrayElements(jhandles, handles, 0);
579   env->ReleaseStringUTFChars(jfname, fname);
580 
581   // release allocated memory
582   if (jkeys != NULL) {
583     for (int i = 0; i < numArgs; i++) {
584       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
585       env->ReleaseStringUTFChars(jkey, keys[i]);
586       env->DeleteLocalRef(jkey);
587     }
588     delete[] keys;
589   }
590 
591   return ret;
592 }
593 
Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType(JNIEnv * env,jobject obj,jlong jhandle,jobject jdtype)594 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType
595   (JNIEnv * env, jobject obj, jlong jhandle, jobject jdtype) {
596   int dtype;
597   int ret = MXNDArrayGetDType(reinterpret_cast<NDArrayHandle>(jhandle), &dtype);
598   SetIntField(env, jdtype, dtype);
599   return ret;
600 }
601 
Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType(JNIEnv * env,jobject obj,jlong jhandle,jobject jstype)602 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType
603   (JNIEnv * env, jobject obj, jlong jhandle, jobject jstype) {
604   int stype;
605   int ret = MXNDArrayGetStorageType(reinterpret_cast<NDArrayHandle>(jhandle), &stype);
606   SetIntField(env, jstype, stype);
607   return ret;
608 }
609 
Java_org_apache_mxnet_LibInfo_mxInitPSEnv(JNIEnv * env,jobject obj,jobjectArray jkeys,jobjectArray jvals)610 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxInitPSEnv
611   (JNIEnv *env, jobject obj, jobjectArray jkeys, jobjectArray jvals) {
612   // keys and values
613   int paramSize = env->GetArrayLength(jkeys);
614   const char** keys = new const char*[paramSize];
615   const char** vals = new const char*[paramSize];
616   jstring jkey, jval;
617   // use strcpy and release char* created by JNI inplace
618   for (int i = 0; i < paramSize; i++) {
619     jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
620     const char* ckey = env->GetStringUTFChars(jkey, 0);
621     keys[i] = ckey;
622     env->DeleteLocalRef(jkey);
623 
624     jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, i));
625     const char* cval = env->GetStringUTFChars(jval, 0);
626     vals[i] = cval;
627     env->DeleteLocalRef(jval);
628   }
629 
630   int ret = MXInitPSEnv(static_cast<mx_uint>(paramSize),
631                         static_cast<const char**>(keys),
632                         static_cast<const char**>(vals));
633 
634   // release keys and vals
635   for (int i = 0; i < paramSize; i++) {
636     jstring key = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
637     env->ReleaseStringUTFChars(key, keys[i]);
638     env->DeleteLocalRef(key);
639 
640     jstring value = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, i));
641     env->ReleaseStringUTFChars(value, vals[i]);
642     env->DeleteLocalRef(value);
643   }
644   delete[] keys;
645   delete[] vals;
646 
647   return ret;
648 }
649 
KVStoreServerControllerFunc(int head,const char * body,void * handle)650 extern "C" void KVStoreServerControllerFunc
651   (int head, const char *body, void *handle) {
652   jobject controllerObjGlb = static_cast<jobject>(handle);
653 
654   JNIEnv *env;
655   _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
656 
657   // find java controller method
658   jclass ctrlClass = env->GetObjectClass(controllerObjGlb);
659   jmethodID ctrlFunc = env->GetMethodID(ctrlClass, "invoke", "(ILjava/lang/String;)V");
660 
661   jstring jbody = env->NewStringUTF(body);
662   env->CallVoidMethod(controllerObjGlb, ctrlFunc, head, jbody);
663   env->DeleteLocalRef(jbody);
664 
665   env->DeleteLocalRef(ctrlClass);
666   // FIXME(Yizhi): This function can be called multiple times,
667   // can we find a way to safely destroy this object ?
668   // env->DeleteGlobalRef(controllerObjGlb);
669 }
670 
Java_org_apache_mxnet_LibInfo_mxKVStoreRunServer(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject controllerObj)671 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreRunServer
672   (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject controllerObj) {
673   jobject controllerObjGlb = env->NewGlobalRef(controllerObj);
674   return MXKVStoreRunServer(reinterpret_cast<KVStoreHandle>(kvStorePtr),
675                             KVStoreServerControllerFunc,
676                             reinterpret_cast<void *>(controllerObjGlb));
677 }
678 
KVStoreUpdaterCallbackFunc(int key,NDArrayHandle recv,NDArrayHandle local,void * handle)679 extern "C" void KVStoreUpdaterCallbackFunc
680   (int key, NDArrayHandle recv, NDArrayHandle local, void *handle) {
681   jobject updaterFuncObjGlb = static_cast<jobject>(handle);
682 
683   JNIEnv *env;
684   _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
685 
686   // find java updater method
687   jclass updtClass = env->GetObjectClass(updaterFuncObjGlb);
688   jmethodID updtFunc = env->GetMethodID(updtClass,
689     "update", "(ILorg/apache/mxnet/NDArray;Lorg/apache/mxnet/NDArray;)V");
690 
691   // find java NDArray constructor
692   jclass ndObjClass = env->FindClass("org/apache/mxnet/NDArray");
693   jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "<init>", "(JZZ)V");
694 
695   jobject ndRecv = env->NewObject(ndObjClass, ndObjConstructor,
696                                   reinterpret_cast<jlong>(recv), true);
697   jobject ndLocal = env->NewObject(ndObjClass, ndObjConstructor,
698                                    reinterpret_cast<jlong>(local), true);
699 
700   env->CallVoidMethod(updaterFuncObjGlb, updtFunc, key, ndRecv, ndLocal);
701 
702   env->DeleteLocalRef(ndLocal);
703   env->DeleteLocalRef(ndRecv);
704   env->DeleteLocalRef(ndObjClass);
705   env->DeleteLocalRef(updtClass);
706   // FIXME(Yizhi): This function can be called multiple times,
707   // can we find a way to safely destroy this object ?
708   // env->DeleteGlobalRef(updaterFuncObjGlb);
709 }
710 
Java_org_apache_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject updaterFuncObj)711 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSetUpdater
712   (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject updaterFuncObj) {
713   jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj);
714   return MXKVStoreSetUpdater(reinterpret_cast<KVStoreHandle>(kvStorePtr),
715                              KVStoreUpdaterCallbackFunc,
716                              reinterpret_cast<void *>(updaterFuncObjGlb));
717 }
718 
Java_org_apache_mxnet_LibInfo_mxKVStoreIsWorkerNode(JNIEnv * env,jobject obj,jobject isWorkerRef)719 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreIsWorkerNode
720   (JNIEnv *env, jobject obj, jobject isWorkerRef) {
721   int isWorker;
722   int ret = MXKVStoreIsWorkerNode(&isWorker);
723   SetIntField(env, isWorkerRef, isWorker);
724   return ret;
725 }
726 
Java_org_apache_mxnet_LibInfo_mxKVStoreCreate(JNIEnv * env,jobject obj,jstring name,jobject kvStoreHandle)727 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreCreate
728   (JNIEnv *env, jobject obj, jstring name, jobject kvStoreHandle) {
729   jclass refLongClass = env->FindClass("org/apache/mxnet/Base$RefLong");
730   jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
731 
732   KVStoreHandle out;
733   const char *type = env->GetStringUTFChars(name, 0);
734   int ret = MXKVStoreCreate(type, &out);
735   env->ReleaseStringUTFChars(name, type);
736 
737   env->SetLongField(kvStoreHandle, refLongFid, reinterpret_cast<jlong>(out));
738   return ret;
739 }
740 
Java_org_apache_mxnet_LibInfo_mxKVStoreInit(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jintArray keys,jlongArray values)741 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreInit
742   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys, jlongArray values) {
743   jint *keyArray = env->GetIntArrayElements(keys, NULL);
744   jlong *valueArray = env->GetLongArrayElements(values, NULL);
745   int ret = MXKVStoreInit(reinterpret_cast<KVStoreHandle>(kvStorePtr),
746                           static_cast<mx_uint>(len),
747                           static_cast<const int *>(keyArray),
748                           reinterpret_cast<NDArrayHandle *>(valueArray));
749   env->ReleaseIntArrayElements(keys, keyArray, 0);
750   env->ReleaseLongArrayElements(values, valueArray, 0);
751   return ret;
752 }
753 
Java_org_apache_mxnet_LibInfo_mxKVStoreInitEx(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jobjectArray keys,jlongArray values)754 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreInitEx
755   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys, jlongArray values) {
756   const char **keyArray = new const char *[len];
757   for (int i = 0; i < len; i++) {
758     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
759     const char *key = env->GetStringUTFChars(jkey, 0);
760     keyArray[i] = key;
761     env->DeleteLocalRef(jkey);
762   }
763   jlong *valueArray = env->GetLongArrayElements(values, NULL);
764   int ret = MXKVStoreInitEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
765                           static_cast<mx_uint>(len),
766                           keyArray,
767                           reinterpret_cast<NDArrayHandle *>(valueArray));
768   env->ReleaseLongArrayElements(values, valueArray, 0);
769   for (int i = 0; i < len; i++) {
770     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
771     env->ReleaseStringUTFChars(jkey, keyArray[i]);
772     env->DeleteLocalRef(jkey);
773   }
774   delete[] keyArray;
775   return ret;
776 }
777 
Java_org_apache_mxnet_LibInfo_mxKVStorePush(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jintArray keys,jlongArray values,jint priority)778 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePush
779   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
780     jlongArray values, jint priority) {
781   jint *keyArray = env->GetIntArrayElements(keys, NULL);
782   jlong *valueArray = env->GetLongArrayElements(values, NULL);
783   int ret = MXKVStorePush(reinterpret_cast<KVStoreHandle>(kvStorePtr),
784                           static_cast<mx_uint>(len),
785                           static_cast<const int *>(keyArray),
786                           reinterpret_cast<NDArrayHandle *>(valueArray),
787                           priority);
788   env->ReleaseLongArrayElements(values, valueArray, 0);
789   return ret;
790 }
791 
Java_org_apache_mxnet_LibInfo_mxKVStorePushEx(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jobjectArray keys,jlongArray values,jint priority)792 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePushEx
793   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
794     jlongArray values, jint priority) {
795   const char **keyArray = new const char *[len];
796   for (int i = 0; i < len; i++) {
797     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
798     const char *key = env->GetStringUTFChars(jkey, 0);
799     keyArray[i] = key;
800     env->DeleteLocalRef(jkey);
801   }
802   jlong *valueArray = env->GetLongArrayElements(values, NULL);
803   int ret = MXKVStorePushEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
804                           static_cast<mx_uint>(len),
805                           keyArray,
806                           reinterpret_cast<NDArrayHandle *>(valueArray),
807                           priority);
808   env->ReleaseLongArrayElements(values, valueArray, 0);
809   for (int i = 0; i < len; i++) {
810     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
811     env->ReleaseStringUTFChars(jkey, keyArray[i]);
812     env->DeleteLocalRef(jkey);
813   }
814   delete[] keyArray;
815   return ret;
816 }
817 
Java_org_apache_mxnet_LibInfo_mxKVStorePull(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jintArray keys,jlongArray outs,jint priority)818 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePull
819   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
820     jlongArray outs, jint priority) {
821   jint *keyArray = env->GetIntArrayElements(keys, NULL);
822   jlong *outArray = env->GetLongArrayElements(outs, NULL);
823   int ret = MXKVStorePull(reinterpret_cast<KVStoreHandle>(kvStorePtr),
824                           static_cast<mx_uint>(len),
825                           static_cast<const int *>(keyArray),
826                           reinterpret_cast<NDArrayHandle *>(outArray),
827                           priority);
828   env->ReleaseIntArrayElements(keys, keyArray, 0);
829   env->ReleaseLongArrayElements(outs, outArray, 0);
830   return ret;
831 }
832 
Java_org_apache_mxnet_LibInfo_mxKVStorePullEx(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jobjectArray keys,jlongArray outs,jint priority)833 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePullEx
834   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
835     jlongArray outs, jint priority) {
836   const char **keyArray = new const char *[len];
837   for (int i = 0; i < len; i++) {
838     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
839     const char *key = env->GetStringUTFChars(jkey, 0);
840     keyArray[i] = key;
841     env->DeleteLocalRef(jkey);
842   }
843   jlong *outArray = env->GetLongArrayElements(outs, NULL);
844   int ret = MXKVStorePullEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
845                           static_cast<mx_uint>(len),
846                           keyArray,
847                           reinterpret_cast<NDArrayHandle *>(outArray),
848                           priority);
849   env->ReleaseLongArrayElements(outs, outArray, 0);
850   for (int i = 0; i < len; i++) {
851     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
852     env->ReleaseStringUTFChars(jkey, keyArray[i]);
853     env->DeleteLocalRef(jkey);
854   }
855   delete[] keyArray;
856   return ret;
857 }
858 
Java_org_apache_mxnet_LibInfo_mxKVStoreGetType(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject kvType)859 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetType
860   (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject kvType) {
861   const char *type;
862   int ret = MXKVStoreGetType(reinterpret_cast<KVStoreHandle>(kvStorePtr), &type);
863   jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
864   jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
865   env->SetObjectField(kvType, valueStr, env->NewStringUTF(type));
866   return ret;
867 }
868 
Java_org_apache_mxnet_LibInfo_mxKVStoreSendCommmandToServers(JNIEnv * env,jobject obj,jlong kvStorePtr,jint head,jstring body)869 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSendCommmandToServers
870   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint head, jstring body) {
871   const char *bodyCStr = env->GetStringUTFChars(body, 0);
872   int ret = MXKVStoreSendCommmandToServers(
873     reinterpret_cast<KVStoreHandle>(kvStorePtr), head, bodyCStr);
874   env->ReleaseStringUTFChars(body, bodyCStr);
875   return ret;
876 }
877 
Java_org_apache_mxnet_LibInfo_mxKVStoreBarrier(JNIEnv * env,jobject obj,jlong kvStorePtr)878 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreBarrier
879   (JNIEnv *env, jobject obj, jlong kvStorePtr) {
880   return MXKVStoreBarrier((KVStoreHandle)kvStorePtr);
881 }
882 
Java_org_apache_mxnet_LibInfo_mxKVStoreGetGroupSize(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject sizeRef)883 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetGroupSize
884   (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject sizeRef) {
885   int size;
886   int ret = MXKVStoreGetGroupSize(reinterpret_cast<KVStoreHandle>(kvStorePtr), &size);
887   SetIntField(env, sizeRef, size);
888   return ret;
889 }
890 
Java_org_apache_mxnet_LibInfo_mxKVStoreGetRank(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject rankRef)891 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetRank
892   (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject rankRef) {
893   int rank;
894   int ret = MXKVStoreGetRank(reinterpret_cast<KVStoreHandle>(kvStorePtr), &rank);
895   SetIntField(env, rankRef, rank);
896   return ret;
897 }
898 
Java_org_apache_mxnet_LibInfo_mxKVStoreGetNumDeadNode(JNIEnv * env,jobject obj,jlong kvStorePtr,jint nodeId,jobject numberRef)899 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetNumDeadNode
900   (JNIEnv * env, jobject obj, jlong kvStorePtr, jint nodeId, jobject numberRef) {
901   int number;
902   int ret = MXKVStoreGetNumDeadNode(reinterpret_cast<KVStoreHandle>(kvStorePtr),
903                                     static_cast<const int>(nodeId),
904                                     &number);
905   SetIntField(env, numberRef, number);
906   return ret;
907 }
908 
Java_org_apache_mxnet_LibInfo_mxKVStoreSetBarrierBeforeExit(JNIEnv * env,jobject obj,jlong kvStorePtr,jint doBarrier)909 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSetBarrierBeforeExit
910   (JNIEnv * env, jobject obj, jlong kvStorePtr, jint doBarrier) {
911   return MXKVStoreSetBarrierBeforeExit(reinterpret_cast<KVStoreHandle>(kvStorePtr),
912                                        static_cast<const int>(doBarrier));
913 }
914 
Java_org_apache_mxnet_LibInfo_mxKVStoreFree(JNIEnv * env,jobject obj,jlong ptr)915 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreFree
916   (JNIEnv * env, jobject obj, jlong ptr) {
917   return MXKVStoreFree(reinterpret_cast<KVStoreHandle>(ptr));
918 }
919 
Java_org_apache_mxnet_LibInfo_mxExecutorOutputs(JNIEnv * env,jobject obj,jlong executorPtr,jobject outputs)920 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorOutputs
921   (JNIEnv *env, jobject obj, jlong executorPtr, jobject outputs) {
922   mx_uint outSize;
923   NDArrayHandle *out;
924   int ret = MXExecutorOutputs(reinterpret_cast<ExecutorHandle>(executorPtr), &outSize, &out);
925 
926   jclass longCls = env->FindClass("java/lang/Long");
927   jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
928 
929   // fill java outputs
930   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
931   jmethodID arrayAppend = env->GetMethodID(arrayClass,
932     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
933   for (size_t i = 0; i < outSize; ++i) {
934     env->CallObjectMethod(outputs, arrayAppend,
935                           env->NewObject(longCls, longConst, out[i]));
936   }
937 
938   return ret;
939 }
940 
Java_org_apache_mxnet_LibInfo_mxExecutorFree(JNIEnv * env,jobject obj,jlong ptr)941 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorFree
942   (JNIEnv * env, jobject obj, jlong ptr) {
943   return MXExecutorFree(reinterpret_cast<ExecutorHandle>(ptr));
944 }
945 
Java_org_apache_mxnet_LibInfo_mxExecutorForward(JNIEnv * env,jobject obj,jlong ptr,jint isTrain)946 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorForward
947   (JNIEnv * env, jobject obj, jlong ptr, jint isTrain) {
948   return MXExecutorForward(reinterpret_cast<ExecutorHandle>(ptr), static_cast<int>(isTrain));
949 }
950 
Java_org_apache_mxnet_LibInfo_mxExecutorBackward(JNIEnv * env,jobject obj,jlong executorPtr,jlongArray grads)951 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBackward
952   (JNIEnv * env, jobject obj, jlong executorPtr, jlongArray grads) {
953   int gradsSize = env->GetArrayLength(grads);
954   jlong *gradArr = env->GetLongArrayElements(grads, NULL);
955   int ret = MXExecutorBackward(reinterpret_cast<ExecutorHandle>(executorPtr),
956                                static_cast<mx_uint>(gradsSize),
957                                reinterpret_cast<NDArrayHandle *>(gradArr));
958   env->ReleaseLongArrayElements(grads, gradArr, 0);
959   return ret;
960 }
961 
Java_org_apache_mxnet_LibInfo_mxExecutorReshape(JNIEnv * env,jobject obj,jint partialReshaping,jint allowUpSizing,jint devType,jint devId,jobjectArray jmapKeys,jintArray jmapDevTypes,jintArray jmapDevIds,jobjectArray jprovidedArgShapeNames,jintArray jprovidedArgShapeData,jintArray jprovidedArgShapeIdx,jobject jrefInArgs,jobject jrefArgGrads,jobject jrefAuxStates,jlong jsharedExec,jobject jrefOut)962 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape
963   (JNIEnv * env, jobject obj,
964     jint partialReshaping, jint allowUpSizing, jint devType, jint devId,
965     jobjectArray jmapKeys, jintArray jmapDevTypes, jintArray jmapDevIds,
966     jobjectArray jprovidedArgShapeNames, jintArray jprovidedArgShapeData,
967     jintArray jprovidedArgShapeIdx, jobject jrefInArgs, jobject jrefArgGrads,
968     jobject jrefAuxStates, jlong jsharedExec, jobject jrefOut) {
969   CHECK(jmapKeys != NULL);
970   CHECK(jprovidedArgShapeNames != NULL);
971 
972   int numMapKeys = env->GetArrayLength(jmapKeys);
973   jint *mapDevTypes = env->GetIntArrayElements(jmapDevTypes, NULL);
974   jint *mapDevIds = env->GetIntArrayElements(jmapDevIds, NULL);
975   const char **mapKeys = NULL;
976   if (numMapKeys > 0) {
977     mapKeys = new const char*[numMapKeys];
978     for (int i = 0; i < numMapKeys; ++i) {
979       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
980       mapKeys[i] = env->GetStringUTFChars(jkey, 0);
981       env->DeleteLocalRef(jkey);
982     }
983   }
984 
985   int numProvidedArgShapes = env->GetArrayLength(jprovidedArgShapeNames);
986   jint *providedArgShapeData = env->GetIntArrayElements(jprovidedArgShapeData, NULL);
987   jint *providedArgShapeIdx = env->GetIntArrayElements(jprovidedArgShapeIdx, NULL);
988   const char **providedArgShapeNames = NULL;
989   if (numProvidedArgShapes > 0) {
990     providedArgShapeNames = new const char*[numProvidedArgShapes];
991     for (int i = 0; i < numProvidedArgShapes; ++i) {
992       jstring jkey = reinterpret_cast<jstring>(
993           env->GetObjectArrayElement(jprovidedArgShapeNames, i));
994       providedArgShapeNames[i] = env->GetStringUTFChars(jkey, 0);
995       env->DeleteLocalRef(jkey);
996     }
997   }
998 
999   mx_uint numInArgs = 0;
1000   NDArrayHandle *inArgs;
1001   NDArrayHandle *argGrads;
1002 
1003   mx_uint numAuxStates = 0;
1004   NDArrayHandle *auxStates;
1005 
1006   ExecutorHandle out;
1007 
1008   int ret = MXExecutorReshapeEx(partialReshaping,
1009                                 allowUpSizing,
1010                                 devType,
1011                                 devId,
1012                                 static_cast<mx_uint>(numMapKeys),
1013                                 mapKeys,
1014                                 static_cast<const int*>(mapDevTypes),
1015                                 static_cast<const int*>(mapDevIds),
1016                                 static_cast<const mx_uint>(numProvidedArgShapes),
1017                                 providedArgShapeNames,
1018                                 static_cast<const int*>(providedArgShapeData),
1019                                 reinterpret_cast<const mx_uint*>(providedArgShapeIdx),
1020                                 &numInArgs,
1021                                 &inArgs,
1022                                 &argGrads,
1023                                 &numAuxStates,
1024                                 &auxStates,
1025                                 reinterpret_cast<ExecutorHandle>(jsharedExec),
1026                                 &out);
1027 
1028   jclass longCls = env->FindClass("java/lang/Long");
1029   jmethodID newLong = env->GetMethodID(longCls, "<init>", "(J)V");
1030 
1031   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1032   jmethodID arrayAppend = env->GetMethodID(arrayClass,
1033     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1034 
1035   for (size_t i = 0; i < numInArgs; ++i) {
1036     jobject inArg = env->NewObject(longCls, newLong, inArgs[i]);
1037     env->CallObjectMethod(jrefInArgs, arrayAppend, inArg);
1038     env->DeleteLocalRef(inArg);
1039 
1040     jobject argGrad = env->NewObject(longCls, newLong, argGrads[i]);
1041     env->CallObjectMethod(jrefArgGrads, arrayAppend, argGrad);
1042     env->DeleteLocalRef(argGrad);
1043   }
1044 
1045   for (size_t i = 0; i < numAuxStates; ++i) {
1046     jobject auxState = env->NewObject(longCls, newLong, auxStates[i]);
1047     env->CallObjectMethod(jrefAuxStates, arrayAppend, auxState);
1048     env->DeleteLocalRef(auxState);
1049   }
1050 
1051   SetLongField(env, jrefOut, reinterpret_cast<jlong>(out));
1052 
1053   // release allocated memory
1054   for (int i = 0; i < numMapKeys; i++) {
1055     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
1056     env->ReleaseStringUTFChars(jkey, mapKeys[i]);
1057     env->DeleteLocalRef(jkey);
1058   }
1059   if (mapKeys != NULL) {
1060     delete[] mapKeys;
1061   }
1062 
1063   for (int i = 0; i < numProvidedArgShapes; i++) {
1064     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jprovidedArgShapeNames, i));
1065     env->ReleaseStringUTFChars(jkey, providedArgShapeNames[i]);
1066     env->DeleteLocalRef(jkey);
1067   }
1068   if (providedArgShapeNames != NULL) {
1069     delete[] providedArgShapeNames;
1070   }
1071 
1072   return ret;
1073 }
1074 
Java_org_apache_mxnet_LibInfo_mxExecutorPrint(JNIEnv * env,jobject obj,jlong ptr,jobject debugStr)1075 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorPrint
1076   (JNIEnv * env, jobject obj, jlong ptr, jobject debugStr) {
1077   const char *retDebugStr;
1078   int ret = MXExecutorPrint(reinterpret_cast<ExecutorHandle>(ptr), &retDebugStr);
1079   SetStringField(env, debugStr, retDebugStr);
1080   return ret;
1081 }
1082 
ExecutorMonitorCallbackFunc(const char * name,NDArrayHandle arr,void * handle)1083 extern "C" void ExecutorMonitorCallbackFunc
1084   (const char *name, NDArrayHandle arr, void *handle) {
1085   jobject callbackFuncObjGlb = static_cast<jobject>(handle);
1086 
1087   JNIEnv *env;
1088   _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
1089 
1090   // find java callback method
1091   jclass callbackClass = env->GetObjectClass(callbackFuncObjGlb);
1092   jmethodID callbackFunc = env->GetMethodID(callbackClass, "invoke", "(Ljava/lang/String;J)V");
1093 
1094   // invoke java callback method
1095   jstring jname = env->NewStringUTF(name);
1096   env->CallVoidMethod(callbackFuncObjGlb, callbackFunc, jname, reinterpret_cast<jlong>(arr));
1097   env->DeleteLocalRef(jname);
1098 
1099   env->DeleteLocalRef(callbackClass);
1100   // FIXME(Yizhi): This function can be called multiple times,
1101   // can we find a way to safely destroy this global ref ?
1102   // env->DeleteGlobalRef(callbackFuncObjGlb);
1103 }
Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallback(JNIEnv * env,jobject obj,jlong executorPtr,jobject callbackFuncObj)1104 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallback
1105   (JNIEnv *env, jobject obj, jlong executorPtr, jobject callbackFuncObj) {
1106   jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj);
1107   return MXExecutorSetMonitorCallback(reinterpret_cast<ExecutorHandle>(executorPtr),
1108                                       ExecutorMonitorCallbackFunc,
1109                                       reinterpret_cast<void *>(callbackFuncObjGlb));
1110 }
1111 
Java_org_apache_mxnet_LibInfo_mxGetLastError(JNIEnv * env,jobject obj)1112 JNIEXPORT jstring JNICALL Java_org_apache_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) {
1113   return env->NewStringUTF(MXGetLastError());
1114 }
1115 
1116 // IO funcs
Java_org_apache_mxnet_LibInfo_mxListDataIters(JNIEnv * env,jobject obj,jobject creators)1117 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListDataIters
1118   (JNIEnv * env, jobject obj, jobject creators) {
1119   jclass longCls = env->FindClass("java/lang/Long");
1120   jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
1121 
1122   // scala.collection.mutable.ListBuffer append method
1123   jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1124   jmethodID listAppend = env->GetMethodID(listClass,
1125     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1126 
1127   // Get function list
1128   DataIterCreator *outArray;
1129   mx_uint outSize;
1130   int ret = MXListDataIters(&outSize, &outArray);
1131   for (size_t i = 0; i < outSize; ++i) {
1132     env->CallObjectMethod(creators, listAppend,
1133                           env->NewObject(longCls, longConst,
1134                                          reinterpret_cast<uint64_t>(outArray[i])));
1135   }
1136   return ret;
1137 }
1138 
Java_org_apache_mxnet_LibInfo_mxDataIterCreateIter(JNIEnv * env,jobject obj,jlong creator,jobjectArray jkeys,jobjectArray jvals,jobject dataIterHandleRef)1139 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterCreateIter
1140   (JNIEnv * env, jobject obj, jlong creator, jobjectArray jkeys,
1141     jobjectArray jvals, jobject dataIterHandleRef) {
1142   // keys and values
1143   int paramSize = env->GetArrayLength(jkeys);
1144   const char** keys = new const char*[paramSize];
1145   const char** vals = new const char*[paramSize];
1146   jstring jkey, jval;
1147   // use strcpy and release char* created by JNI inplace
1148   for (int i = 0; i < paramSize; i++) {
1149     jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1150     const char* ckey = env->GetStringUTFChars(jkey, 0);
1151     keys[i] = ckey;
1152     env->DeleteLocalRef(jkey);
1153 
1154     jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, i));
1155     const char* cval = env->GetStringUTFChars(jval, 0);
1156     vals[i] = cval;
1157     env->DeleteLocalRef(jval);
1158   }
1159 
1160   // create iter
1161   DataIterHandle out;
1162   int ret = MXDataIterCreateIter(reinterpret_cast<DataIterCreator>(creator),
1163                                  static_cast<mx_uint>(paramSize),
1164                                  static_cast<const char**>(keys),
1165                                  static_cast<const char**>(vals),
1166                                  &out);
1167   SetLongField(env, dataIterHandleRef, reinterpret_cast<jlong>(out));
1168 
1169   // release keys and vals
1170   for (int i = 0; i < paramSize; i++) {
1171     jstring key = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1172     env->ReleaseStringUTFChars(key, keys[i]);
1173     env->DeleteLocalRef(key);
1174 
1175     jstring value = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, i));
1176     env->ReleaseStringUTFChars(value, vals[i]);
1177     env->DeleteLocalRef(value);
1178   }
1179   delete[] keys;
1180   delete[] vals;
1181 
1182   return ret;
1183 }
1184 
Java_org_apache_mxnet_LibInfo_mxDataIterGetIterInfo(JNIEnv * env,jobject obj,jlong creator,jobject jname,jobject jdesc,jobject jargNames,jobject jargTypeInfos,jobject jargDescs)1185 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetIterInfo
1186   (JNIEnv * env, jobject obj, jlong creator, jobject jname,
1187     jobject jdesc, jobject jargNames, jobject jargTypeInfos, jobject jargDescs) {
1188   const char* name;
1189   const char* description;
1190   mx_uint numArgs;
1191   const char** argNames;
1192   const char** argTypeInfos;
1193   const char** argDescs;
1194   int ret = MXDataIterGetIterInfo(reinterpret_cast<DataIterCreator>(creator),
1195                                    &name,
1196                                    &description,
1197                                    &numArgs,
1198                                    &argNames,
1199                                    &argTypeInfos,
1200                                    &argDescs);
1201 
1202   jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
1203   jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
1204   // set params
1205   env->SetObjectField(jname, valueStr, env->NewStringUTF(name));
1206   env->SetObjectField(jdesc, valueStr, env->NewStringUTF(description));
1207   jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1208   jmethodID listAppend = env->GetMethodID(listClass,
1209     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1210   for (size_t i = 0; i < numArgs; i++) {
1211     env->CallObjectMethod(jargNames, listAppend, env->NewStringUTF(argNames[i]));
1212     env->CallObjectMethod(jargTypeInfos, listAppend, env->NewStringUTF(argTypeInfos[i]));
1213     env->CallObjectMethod(jargDescs, listAppend, env->NewStringUTF(argDescs[i]));
1214   }
1215   return ret;
1216 }
1217 
Java_org_apache_mxnet_LibInfo_mxDataIterFree(JNIEnv * env,jobject obj,jlong handle)1218 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterFree
1219   (JNIEnv *env, jobject obj, jlong handle) {
1220   int ret = MXDataIterFree(reinterpret_cast<DataIterHandle>(handle));
1221   return ret;
1222 }
1223 
Java_org_apache_mxnet_LibInfo_mxDataIterBeforeFirst(JNIEnv * env,jobject obj,jlong handle)1224 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterBeforeFirst
1225   (JNIEnv *env, jobject obj, jlong handle) {
1226   int ret = MXDataIterBeforeFirst(reinterpret_cast<DataIterHandle>(handle));
1227   return ret;
1228 }
1229 
Java_org_apache_mxnet_LibInfo_mxDataIterNext(JNIEnv * env,jobject obj,jlong handle,jobject out)1230 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterNext
1231   (JNIEnv *env, jobject obj, jlong handle, jobject out) {
1232   int cout;
1233   int ret = MXDataIterNext(reinterpret_cast<DataIterHandle>(handle), &cout);
1234   SetIntField(env, out, cout);
1235   return ret;
1236 }
1237 
Java_org_apache_mxnet_LibInfo_mxDataIterGetLabel(JNIEnv * env,jobject obj,jlong handle,jobject ndArrayHandleRef)1238 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetLabel
1239   (JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) {
1240   NDArrayHandle out;
1241   int ret = MXDataIterGetLabel(reinterpret_cast<DataIterHandle>(handle), &out);
1242   SetLongField(env, ndArrayHandleRef, reinterpret_cast<jlong>(out));
1243   return ret;
1244 }
1245 
Java_org_apache_mxnet_LibInfo_mxDataIterGetData(JNIEnv * env,jobject obj,jlong handle,jobject ndArrayHandleRef)1246 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetData
1247   (JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) {
1248   NDArrayHandle out;
1249   int ret = MXDataIterGetData(reinterpret_cast<DataIterHandle>(handle), &out);
1250   SetLongField(env, ndArrayHandleRef, reinterpret_cast<jlong>(out));
1251   return ret;
1252 }
1253 
Java_org_apache_mxnet_LibInfo_mxDataIterGetIndex(JNIEnv * env,jobject obj,jlong handle,jobject outIndex,jobject outSize)1254 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetIndex
1255   (JNIEnv *env, jobject obj, jlong handle, jobject outIndex, jobject outSize) {
1256   uint64_t* coutIndex;
1257   uint64_t coutSize;
1258   int ret = MXDataIterGetIndex(reinterpret_cast<DataIterHandle>(handle), &coutIndex, &coutSize);
1259   // set field
1260   SetLongField(env, outSize, static_cast<jlong>(coutSize));
1261   // scala.collection.mutable.ListBuffer append method
1262   jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1263   jmethodID listAppend = env->GetMethodID(listClass,
1264     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1265 
1266   // long class
1267   jclass longCls = env->FindClass("java/lang/Long");
1268   jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
1269 
1270   for (size_t i = 0; i < coutSize; i++) {
1271     env->CallObjectMethod(outIndex, listAppend,
1272                           env->NewObject(longCls, longConst, coutIndex[i]));
1273   }
1274   return ret;
1275 }
1276 
Java_org_apache_mxnet_LibInfo_mxDataIterGetPadNum(JNIEnv * env,jobject obj,jlong handle,jobject pad)1277 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetPadNum
1278   (JNIEnv *env, jobject obj, jlong handle, jobject pad) {
1279   int cpad;
1280   int ret = MXDataIterGetPadNum((DataIterHandle)handle, &cpad);
1281   SetIntField(env, pad, cpad);
1282   return ret;
1283 }
1284 
1285 // Symbol functions
Java_org_apache_mxnet_LibInfo_mxSymbolFree(JNIEnv * env,jobject obj,jlong ptr)1286 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolFree
1287   (JNIEnv * env, jobject obj, jlong ptr) {
1288   return MXSymbolFree((SymbolHandle) ptr);
1289 }
1290 
Java_org_apache_mxnet_LibInfo_mxSymbolListAtomicSymbolCreators(JNIEnv * env,jobject obj,jobject symbolList)1291 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAtomicSymbolCreators
1292   (JNIEnv *env, jobject obj, jobject symbolList) {
1293   mx_uint outSize;
1294   AtomicSymbolCreator *outArray;
1295   int ret = MXSymbolListAtomicSymbolCreators(&outSize, &outArray);
1296 
1297   jclass longCls = env->FindClass("java/lang/Long");
1298   jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
1299 
1300   jclass listCls = env->FindClass("scala/collection/mutable/ListBuffer");
1301   jmethodID listAppend = env->GetMethodID(listCls,
1302     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1303 
1304   for (size_t i = 0; i < outSize; ++i) {
1305     env->CallObjectMethod(symbolList, listAppend,
1306                           env->NewObject(longCls, longConst, outArray[i]));
1307   }
1308 
1309   return ret;
1310 }
1311 
Java_org_apache_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo(JNIEnv * env,jobject obj,jlong symbolPtr,jobject name,jobject desc,jobject numArgs,jobject argNames,jobject argTypes,jobject argDescs,jobject keyVarNumArgs)1312 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo
1313   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject name, jobject desc, jobject numArgs,
1314     jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs) {
1315 
1316   const char *cName;
1317   const char *cDesc;
1318   mx_uint cNumArgs;
1319   const char **cArgNames;
1320   const char **cArgTypes;
1321   const char **cArgDescs;
1322   const char *cKeyVarNumArgs;
1323 
1324   int ret = MXSymbolGetAtomicSymbolInfo(reinterpret_cast<AtomicSymbolCreator>(symbolPtr),
1325                                         &cName, &cDesc, &cNumArgs,
1326                                         &cArgNames, &cArgTypes, &cArgDescs,
1327                                         &cKeyVarNumArgs);
1328 
1329   jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
1330   jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
1331 
1332   jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
1333   jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
1334 
1335   // scala.collection.mutable.ListBuffer append method
1336   jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1337   jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
1338       "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1339 
1340   env->SetObjectField(name, valueStr, env->NewStringUTF(cName));
1341   env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc));
1342   env->SetObjectField(keyVarNumArgs, valueStr, env->NewStringUTF(cKeyVarNumArgs));
1343   env->SetIntField(numArgs, valueInt, static_cast<jint>(cNumArgs));
1344   for (size_t i = 0; i < cNumArgs; ++i) {
1345     env->CallObjectMethod(argNames, listAppend, env->NewStringUTF(cArgNames[i]));
1346     env->CallObjectMethod(argTypes, listAppend, env->NewStringUTF(cArgTypes[i]));
1347     env->CallObjectMethod(argDescs, listAppend, env->NewStringUTF(cArgDescs[i]));
1348   }
1349 
1350   return ret;
1351 }
1352 
Java_org_apache_mxnet_LibInfo_mxSymbolCreateAtomicSymbol(JNIEnv * env,jobject obj,jlong symbolPtr,jobjectArray paramKeys,jobjectArray paramVals,jobject symbolRef)1353 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateAtomicSymbol
1354   (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray paramKeys,
1355     jobjectArray paramVals, jobject symbolRef) {
1356   int paramSize = env->GetArrayLength(paramKeys);
1357   const char **keys = new const char*[paramSize];
1358   const char **vals = new const char*[paramSize];
1359   for (int i = 0; i < paramSize; i++) {
1360     jstring key = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
1361     const char *rawKey = env->GetStringUTFChars(key, 0);
1362     keys[i] = rawKey;
1363     env->DeleteLocalRef(key);
1364 
1365     jstring value = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
1366     const char *rawValue = env->GetStringUTFChars(value, 0);
1367     vals[i] = rawValue;
1368     env->DeleteLocalRef(value);
1369   }
1370 
1371   SymbolHandle out;
1372   int ret = MXSymbolCreateAtomicSymbol(reinterpret_cast<AtomicSymbolCreator>(symbolPtr),
1373     static_cast<mx_uint>(paramSize), keys, vals, &out);
1374   SetLongField(env, symbolRef, reinterpret_cast<jlong>(out));
1375 
1376   // release keys and vals
1377   for (int i = 0; i < paramSize; i++) {
1378     jstring key = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
1379     env->ReleaseStringUTFChars(key, keys[i]);
1380     env->DeleteLocalRef(key);
1381 
1382     jstring value = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
1383     env->ReleaseStringUTFChars(value, vals[i]);
1384     env->DeleteLocalRef(value);
1385   }
1386   delete[] keys;
1387   delete[] vals;
1388 
1389   return ret;
1390 }
1391 
Java_org_apache_mxnet_LibInfo_mxSymbolSetAttr(JNIEnv * env,jobject obj,jlong symbolPtr,jstring jkey,jstring jvalue)1392 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSetAttr
1393   (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jstring jvalue) {
1394   const char *ckey = env->GetStringUTFChars(jkey, 0);
1395   const char *cvalue = env->GetStringUTFChars(jvalue, 0);
1396   int ret = MXSymbolSetAttr(reinterpret_cast<SymbolHandle>(symbolPtr), ckey, cvalue);
1397   env->ReleaseStringUTFChars(jkey, ckey);
1398   env->ReleaseStringUTFChars(jvalue, cvalue);
1399   return ret;
1400 }
1401 
Java_org_apache_mxnet_LibInfo_mxSymbolListAttrShallow(JNIEnv * env,jobject obj,jlong symbolPtr,jobject joutSize,jobject jout)1402 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAttrShallow
1403   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject joutSize, jobject jout) {
1404   mx_uint outSize;
1405   const char** out;
1406 
1407   int ret = MXSymbolListAttrShallow(reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &out);
1408 
1409   jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
1410   jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
1411   env->SetIntField(joutSize, valueInt, static_cast<jint>(outSize));
1412 
1413   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1414   jmethodID arrayAppend = env->GetMethodID(arrayClass,
1415     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1416   for (size_t i = 0; i < outSize * 2; ++i) {
1417     jstring jtmp = env->NewStringUTF(out[i]);
1418     env->CallObjectMethod(jout, arrayAppend, jtmp);
1419     env->DeleteLocalRef(jtmp);
1420   }
1421 
1422   return ret;
1423 }
1424 
Java_org_apache_mxnet_LibInfo_mxSymbolListAttr(JNIEnv * env,jobject obj,jlong symbolPtr,jobject joutSize,jobject jout)1425 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAttr
1426   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject joutSize, jobject jout) {
1427   mx_uint outSize;
1428   const char** out;
1429 
1430   int ret = MXSymbolListAttr(reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &out);
1431 
1432   jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
1433   jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
1434   env->SetIntField(joutSize, valueInt, static_cast<jint>(outSize));
1435 
1436   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1437   jmethodID arrayAppend = env->GetMethodID(arrayClass,
1438     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1439   for (size_t i = 0; i < outSize * 2; ++i) {
1440     jstring jtmp = env->NewStringUTF(out[i]);
1441     env->CallObjectMethod(jout, arrayAppend, jtmp);
1442     env->DeleteLocalRef(jtmp);
1443   }
1444 
1445   return ret;
1446 }
1447 
Java_org_apache_mxnet_LibInfo_mxSymbolCompose(JNIEnv * env,jobject obj,jlong symbolPtr,jstring jname,jobjectArray jkeys,jlongArray jargs)1448 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCompose
1449   (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jname,
1450     jobjectArray jkeys, jlongArray jargs) {
1451   int argSize = env->GetArrayLength(jargs);
1452   const char **keys = NULL;
1453   if (jkeys != NULL) {
1454     keys = new const char*[argSize];
1455     for (int i = 0; i < argSize; i++) {
1456       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1457       const char *key = env->GetStringUTFChars(jkey, 0);
1458       keys[i] = key;
1459       env->DeleteLocalRef(jkey);
1460     }
1461   }
1462   jlong *args = env->GetLongArrayElements(jargs, NULL);
1463   const char *name = env->GetStringUTFChars(jname, 0);
1464   int ret = MXSymbolCompose(reinterpret_cast<SymbolHandle>(symbolPtr),
1465                             name, static_cast<mx_uint>(argSize), keys,
1466                             reinterpret_cast<SymbolHandle *>(args));
1467   env->ReleaseStringUTFChars(jname, name);
1468   env->ReleaseLongArrayElements(jargs, args, 0);
1469   // release allocated memory
1470   if (jkeys != NULL) {
1471     for (int i = 0; i < argSize; i++) {
1472       jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
1473       env->ReleaseStringUTFChars(jkey, keys[i]);
1474       env->DeleteLocalRef(jkey);
1475     }
1476     delete[] keys;
1477   }
1478   return ret;
1479 }
1480 
Java_org_apache_mxnet_LibInfo_mxSymbolCreateVariable(JNIEnv * env,jobject obj,jstring jname,jobject handle)1481 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateVariable
1482   (JNIEnv *env, jobject obj, jstring jname, jobject handle) {
1483   SymbolHandle out;
1484   const char *name = env->GetStringUTFChars(jname, 0);
1485   int ret = MXSymbolCreateVariable(name, &out);
1486   env->ReleaseStringUTFChars(jname, name);
1487   SetLongField(env, handle, reinterpret_cast<jlong>(out));
1488   return ret;
1489 }
1490 
Java_org_apache_mxnet_LibInfo_mxSymbolGetAttr(JNIEnv * env,jobject obj,jlong symbolPtr,jstring jkey,jobject retRef,jobject successRef)1491 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAttr
1492   (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jobject retRef, jobject successRef) {
1493   const char *out;
1494   int success;
1495   const char *key = env->GetStringUTFChars(jkey, 0);
1496   int ret = MXSymbolGetAttr(reinterpret_cast<SymbolHandle>(symbolPtr), key, &out, &success);
1497   env->ReleaseStringUTFChars(jkey, key);
1498 
1499   SetStringField(env, retRef, out);
1500   SetIntField(env, successRef, success);
1501   return ret;
1502 }
1503 
Java_org_apache_mxnet_LibInfo_mxSymbolListArguments(JNIEnv * env,jobject obj,jlong symbolPtr,jobject arguments)1504 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListArguments
1505   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject arguments) {
1506   mx_uint outSize;
1507   const char **outStrArray;
1508   int ret = MXSymbolListArguments(
1509     reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &outStrArray);
1510 
1511   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1512   jmethodID arrayAppend = env->GetMethodID(arrayClass,
1513     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1514   for (size_t i = 0; i < outSize; i++) {
1515     jstring argument = env->NewStringUTF(outStrArray[i]);
1516     env->CallObjectMethod(arguments, arrayAppend, argument);
1517     env->DeleteLocalRef(argument);
1518   }
1519 
1520   return ret;
1521 }
1522 
Java_org_apache_mxnet_LibInfo_mxSymbolListOutputs(JNIEnv * env,jobject obj,jlong symbolPtr,jobject outputs)1523 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListOutputs
1524   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
1525   mx_uint outSize;
1526   const char **outStrArray;
1527   int ret = MXSymbolListOutputs(reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &outStrArray);
1528 
1529   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1530   jmethodID arrayAppend = env->GetMethodID(arrayClass,
1531     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1532   for (size_t i = 0; i < outSize; i++) {
1533     jstring output = env->NewStringUTF(outStrArray[i]);
1534     env->CallObjectMethod(outputs, arrayAppend, output);
1535     env->DeleteLocalRef(output);
1536   }
1537 
1538   return ret;
1539 }
1540 
Java_org_apache_mxnet_LibInfo_mxSymbolListAuxiliaryStates(JNIEnv * env,jobject obj,jlong symbolPtr,jobject outputs)1541 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAuxiliaryStates
1542   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
1543   mx_uint outSize;
1544   const char **outStrArray;
1545   int ret = MXSymbolListAuxiliaryStates(
1546     reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &outStrArray);
1547 
1548   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1549   jmethodID arrayAppend = env->GetMethodID(arrayClass,
1550     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1551   for (size_t i = 0; i < outSize; i++) {
1552     jstring output = env->NewStringUTF(outStrArray[i]);
1553     env->CallObjectMethod(outputs, arrayAppend, output);
1554     env->DeleteLocalRef(output);
1555   }
1556 
1557   return ret;
1558 }
1559 
Java_org_apache_mxnet_LibInfo_mxSymbolCopy(JNIEnv * env,jobject obj,jlong symbolPtr,jobject clonedSymbolRef)1560 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCopy
1561   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) {
1562   SymbolHandle clonedSymbol;
1563   int ret = MXSymbolCopy(reinterpret_cast<SymbolHandle>(symbolPtr), &clonedSymbol);
1564   SetLongField(env, clonedSymbolRef, reinterpret_cast<jlong>(clonedSymbol));
1565   return ret;
1566 }
1567 
Java_org_apache_mxnet_LibInfo_mxSymbolCreateGroup(JNIEnv * env,jobject obj,jlongArray jsymbols,jobject out)1568 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateGroup
1569   (JNIEnv *env, jobject obj, jlongArray jsymbols, jobject out) {
1570   int numSymbols = env->GetArrayLength(jsymbols);
1571   SymbolHandle handle;
1572   jlong *symbols = env->GetLongArrayElements(jsymbols, NULL);
1573   int ret = MXSymbolCreateGroup(numSymbols, reinterpret_cast<SymbolHandle *>(symbols), &handle);
1574   env->ReleaseLongArrayElements(jsymbols, symbols, 0);
1575   SetLongField(env, out, reinterpret_cast<jlong>(handle));
1576   return ret;
1577 }
1578 
Java_org_apache_mxnet_LibInfo_mxSymbolPrint(JNIEnv * env,jobject obj,jlong symbolPtr,jobject out)1579 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolPrint
1580   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject out) {
1581   const char *outStr;
1582   int ret = MXSymbolPrint(reinterpret_cast<SymbolHandle>(symbolPtr), &outStr);
1583   SetStringField(env, out, outStr);
1584   return ret;
1585 }
1586 
Java_org_apache_mxnet_LibInfo_mxSymbolGetOutput(JNIEnv * env,jobject obj,jlong symbolPtr,jint index,jobject jout)1587 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetOutput
1588   (JNIEnv *env, jobject obj, jlong symbolPtr, jint index, jobject jout) {
1589   SymbolHandle out;
1590   int ret = MXSymbolGetOutput(reinterpret_cast<SymbolHandle>(symbolPtr),
1591                               static_cast<mx_uint>(index), &out);
1592   SetLongField(env, jout, reinterpret_cast<jlong>(out));
1593   return ret;
1594 }
1595 
Java_org_apache_mxnet_LibInfo_mxSymbolGetInternals(JNIEnv * env,jobject obj,jlong symbolPtr,jobject jout)1596 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetInternals
1597   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject jout) {
1598   SymbolHandle out;
1599   int ret = MXSymbolGetInternals(reinterpret_cast<SymbolHandle>(symbolPtr), &out);
1600   SetLongField(env, jout, reinterpret_cast<jlong>(out));
1601   return ret;
1602 }
1603 
Java_org_apache_mxnet_LibInfo_mxSymbolInferType(JNIEnv * env,jobject obj,jlong symbolPtr,jobjectArray jkeys,jintArray jvals,jobject jargTypeData,jobject joutTypeData,jobject jauxTypeData,jobject jcomplete)1604 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferType
1605   (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray jkeys, jintArray jvals,
1606     jobject jargTypeData, jobject joutTypeData, jobject jauxTypeData, jobject jcomplete) {
1607   int numArgs = env->GetArrayLength(jvals);
1608   const char **keys = NULL;
1609   if (jkeys != NULL) {
1610     keys = new const char *[numArgs];
1611     for (int i = 0; i < numArgs; i++) {
1612       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1613       const char *key = env->GetStringUTFChars(jkey, 0);
1614       keys[i] = key;
1615       env->DeleteLocalRef(jkey);
1616     }
1617   }
1618 
1619   mx_uint inTypeSize;
1620   const int *inTypeData;
1621   mx_uint outTypeSize;
1622   const int *outTypeData;
1623   mx_uint auxTypeSize;
1624   const int *auxTypeData;
1625   int complete;
1626 
1627   jint *vals = env->GetIntArrayElements(jvals, NULL);
1628   int ret = MXSymbolInferType(reinterpret_cast<SymbolHandle>(symbolPtr),
1629                               static_cast<mx_uint>(numArgs), keys,
1630                               static_cast<const int *>(vals),
1631                               &inTypeSize, &inTypeData,
1632                               &outTypeSize, &outTypeData,
1633                               &auxTypeSize, &auxTypeData,
1634                               &complete);
1635   env->ReleaseIntArrayElements(jvals, vals, 0);
1636 
1637   jclass integerClass = env->FindClass("java/lang/Integer");
1638   jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");
1639 
1640   jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1641   jmethodID listAppend = env->GetMethodID(listClass,
1642     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1643 
1644   for (size_t i = 0; i < inTypeSize; ++i) {
1645     jobject data = env->NewObject(integerClass, newInteger, inTypeData[i]);
1646     env->CallObjectMethod(jargTypeData, listAppend, data);
1647     env->DeleteLocalRef(data);
1648   }
1649   for (size_t i = 0; i < outTypeSize; ++i) {
1650     jobject data = env->NewObject(integerClass, newInteger, outTypeData[i]);
1651     env->CallObjectMethod(joutTypeData, listAppend, data);
1652     env->DeleteLocalRef(data);
1653   }
1654   for (size_t i = 0; i < auxTypeSize; ++i) {
1655     jobject data = env->NewObject(integerClass, newInteger, auxTypeData[i]);
1656     env->CallObjectMethod(jauxTypeData, listAppend, data);
1657     env->DeleteLocalRef(data);
1658   }
1659 
1660   SetIntField(env, jcomplete, complete);
1661 
1662   // release allocated memory
1663   if (jkeys != NULL) {
1664     for (int i = 0; i < numArgs; i++) {
1665       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1666       env->ReleaseStringUTFChars(jkey, keys[i]);
1667       env->DeleteLocalRef(jkey);
1668     }
1669     delete[] keys;
1670   }
1671 
1672   return ret;
1673 }
1674 
Java_org_apache_mxnet_LibInfo_mxSymbolSaveToJSON(JNIEnv * env,jobject obj,jlong symbolPtr,jobject jout)1675 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSaveToJSON
1676   (JNIEnv *env, jobject obj, jlong symbolPtr, jobject jout) {
1677   const char *out;
1678   int ret = MXSymbolSaveToJSON(reinterpret_cast<SymbolHandle>(symbolPtr), &out);
1679   SetStringField(env, jout, out);
1680   return ret;
1681 }
1682 
Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromJSON(JNIEnv * env,jobject obj,jstring json,jobject jhandleRef)1683 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromJSON
1684   (JNIEnv *env, jobject obj, jstring json, jobject jhandleRef) {
1685   const char *str = env->GetStringUTFChars(json, 0);
1686   SymbolHandle out;
1687   int ret = MXSymbolCreateFromJSON(str, &out);
1688   SetLongField(env, jhandleRef, reinterpret_cast<jlong>(out));
1689   env->ReleaseStringUTFChars(json, str);
1690   return ret;
1691 }
1692 
Java_org_apache_mxnet_LibInfo_mxSymbolSaveToFile(JNIEnv * env,jobject obj,jlong symbolPtr,jstring jfname)1693 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSaveToFile
1694   (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jfname) {
1695   const char *fname = env->GetStringUTFChars(jfname, 0);
1696   int ret = MXSymbolSaveToFile(reinterpret_cast<SymbolHandle>(symbolPtr), fname);
1697   env->ReleaseStringUTFChars(jfname, fname);
1698   return ret;
1699 }
1700 
Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile(JNIEnv * env,jobject obj,jstring jfname,jobject jhandleRef)1701 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile
1702   (JNIEnv *env, jobject obj, jstring jfname, jobject jhandleRef) {
1703   const char *fname = env->GetStringUTFChars(jfname, 0);
1704   SymbolHandle out;
1705   int ret = MXSymbolCreateFromFile(fname, &out);
1706   SetLongField(env, jhandleRef, reinterpret_cast<jlong>(out));
1707   env->ReleaseStringUTFChars(jfname, fname);
1708   return ret;
1709 }
1710 
FillSymbolInferShape(JNIEnv * env,jmethodID listAppend,jobject joutData,int shapeSize,const int * shapeNdim,const int ** shapeData)1711 int FillSymbolInferShape
1712   (JNIEnv *env, jmethodID listAppend, jobject joutData,
1713     int shapeSize, const int *shapeNdim, const int **shapeData) {
1714   for (int i = 0; i < shapeSize; ++i) {
1715     jintArray jshape = NULL;
1716     if (shapeNdim[i] >= 0) {
1717       jshape = env->NewIntArray(shapeNdim[i]);
1718       if (jshape == NULL) {
1719         // TODO(Yizhi): out of memory error thrown, return a specific error code ?
1720         return -1;
1721       }
1722       env->SetIntArrayRegion(jshape, 0, shapeNdim[i], reinterpret_cast<const jint *>(shapeData[i]));
1723     }
1724     env->CallObjectMethod(joutData, listAppend, jshape);
1725     env->DeleteLocalRef(jshape);
1726   }
1727   return 0;
1728 }
1729 
SymbolInferShapeHelper(JNIEnv * env,jobject obj,jlong symbolPtr,jint jnumArgs,jobjectArray jkeys,jintArray jargIndPtr,jintArray jargShapeData,jobject jinShapeData,jobject joutShapeData,jobject jauxShapeData,jobject jcomplete,bool partial)1730 int SymbolInferShapeHelper(JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs,
1731                             jobjectArray jkeys, jintArray jargIndPtr, jintArray jargShapeData,
1732                             jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData,
1733                             jobject jcomplete, bool partial) {
1734   const char **keys = NULL;
1735   if (jkeys != NULL) {
1736     keys = new const char *[jnumArgs];
1737     for (int i = 0; i < jnumArgs; i++) {
1738       jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
1739       const char *key = env->GetStringUTFChars(jkey, 0);
1740       keys[i] = key;
1741       env->DeleteLocalRef(jkey);
1742     }
1743   }
1744 
1745   mx_uint inShapeSize;
1746   const int *inShapeNdim;
1747   const int **inShapeData;
1748 
1749   mx_uint outShapeSize;
1750   const int *outShapeNdim;
1751   const int **outShapeData;
1752 
1753   mx_uint auxShapeSize;
1754   const int *auxShapeNdim;
1755   const int **auxShapeData;
1756 
1757   int complete;
1758 
1759   jint *argIndPtr = env->GetIntArrayElements(jargIndPtr, NULL);
1760   jint *argShapeData = env->GetIntArrayElements(jargShapeData, NULL);
1761   int ret;
1762   if (!partial) {
1763     ret = MXSymbolInferShapeEx(reinterpret_cast<SymbolHandle>(symbolPtr),
1764                                static_cast<mx_uint>(jnumArgs),
1765                                keys,
1766                                reinterpret_cast<mx_uint *>(argIndPtr),
1767                                reinterpret_cast<const int *>(argShapeData),
1768                                &inShapeSize,
1769                                &inShapeNdim,
1770                                &inShapeData,
1771                                &outShapeSize,
1772                                &outShapeNdim,
1773                                &outShapeData,
1774                                &auxShapeSize,
1775                                &auxShapeNdim,
1776                                &auxShapeData,
1777                                &complete);
1778   } else {
1779     ret = MXSymbolInferShapePartialEx(reinterpret_cast<SymbolHandle>(symbolPtr),
1780                                       static_cast<mx_uint>(jnumArgs),
1781                                       keys,
1782                                       reinterpret_cast<mx_uint *>(argIndPtr),
1783                                       reinterpret_cast<const int *>(argShapeData),
1784                                       &inShapeSize,
1785                                       &inShapeNdim,
1786                                       &inShapeData,
1787                                       &outShapeSize,
1788                                       &outShapeNdim,
1789                                       &outShapeData,
1790                                       &auxShapeSize,
1791                                       &auxShapeNdim,
1792                                       &auxShapeData,
1793                                       &complete);
1794   }
1795   env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0);
1796   env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0);
1797 
1798   if (ret == 0) {
1799     jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1800     jmethodID listAppend = env->GetMethodID(listClass,
1801       "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1802 
1803     if (FillSymbolInferShape(
1804           env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData)) {
1805       // TODO(Yizhi): out of memory error thrown, return a specific error code ?
1806       return -1;
1807     }
1808     if (FillSymbolInferShape(
1809           env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData)) {
1810       // TODO(Yizhi): out of memory error thrown, return a specific error code ?
1811       return -1;
1812     }
1813     if (FillSymbolInferShape(
1814           env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) {
1815       // TODO(Yizhi): out of memory error thrown, return a specific error code ?
1816       return -1;
1817     }
1818 
1819     SetIntField(env, jcomplete, complete);
1820   }
1821 
1822   // release allocated memory
1823   if (jkeys != NULL) {
1824     for (int i = 0; i < jnumArgs; i++) {
1825       jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1826       env->ReleaseStringUTFChars(jkey, keys[i]);
1827       env->DeleteLocalRef(jkey);
1828     }
1829     delete[] keys;
1830   }
1831 
1832   return ret;
1833 }
1834 
Java_org_apache_mxnet_LibInfo_mxSymbolInferShape(JNIEnv * env,jobject obj,jlong symbolPtr,jint jnumArgs,jobjectArray jkeys,jintArray jargIndPtr,jintArray jargShapeData,jobject jinShapeData,jobject joutShapeData,jobject jauxShapeData,jobject jcomplete)1835 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
1836   (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
1837     jintArray jargIndPtr, jintArray jargShapeData,
1838     jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
1839 
1840   return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData,
1841                                 jinShapeData, joutShapeData, jauxShapeData, jcomplete, false);
1842 }
1843 
Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial(JNIEnv * env,jobject obj,jlong symbolPtr,jint jnumArgs,jobjectArray jkeys,jintArray jargIndPtr,jintArray jargShapeData,jobject jinShapeData,jobject joutShapeData,jobject jauxShapeData,jobject jcomplete)1844 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial
1845   (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
1846     jintArray jargIndPtr, jintArray jargShapeData,
1847     jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
1848 
1849   return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData,
1850                                 jinShapeData, joutShapeData, jauxShapeData, jcomplete, true);
1851 }
1852 
Java_org_apache_mxnet_LibInfo_mxExecutorBindX(JNIEnv * env,jobject obj,jlong symbolPtr,jint deviceTypeId,jint deviceID,jint numCtx,jobjectArray jctxMapKeys,jintArray jctxMapDevTypes,jintArray jctxMapDevIDs,jint numArgs,jlongArray jargsHandle,jlongArray jargsGradHandle,jintArray jreqsArray,jlongArray jauxArgsHandle,jobject jexecOut)1853 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindX
1854   (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
1855     jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
1856     jlongArray jargsHandle, jlongArray jargsGradHandle, jintArray jreqsArray,
1857     jlongArray jauxArgsHandle, jobject jexecOut) {
1858   ExecutorHandle out;
1859   int auxStatesLen = env->GetArrayLength(jauxArgsHandle);
1860 
1861   const char **mapKeys = new const char *[numCtx];
1862   for (int i = 0; i < numCtx; i++) {
1863     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jctxMapKeys, i));
1864     const char *key = env->GetStringUTFChars(jkey, 0);
1865     mapKeys[i] = key;
1866     env->DeleteLocalRef(jkey);
1867   }
1868   jlong *auxStates = env->GetLongArrayElements(jauxArgsHandle, NULL);
1869   jint *gradReqType = env->GetIntArrayElements(jreqsArray, NULL);
1870   jlong *inArgs = env->GetLongArrayElements(jargsHandle, NULL);
1871   jlong *argGradStore = env->GetLongArrayElements(jargsGradHandle, NULL);
1872   jint *mapDevTypes = env->GetIntArrayElements(jctxMapDevTypes, NULL);
1873   jint *mapDevIDs = env->GetIntArrayElements(jctxMapDevIDs, NULL);
1874   int ret = MXExecutorBindX(reinterpret_cast<SymbolHandle>(symbolPtr),
1875                             deviceTypeId,
1876                             deviceID,
1877                             static_cast<mx_uint>(numCtx),
1878                             mapKeys,
1879                             mapDevTypes,
1880                             mapDevIDs,
1881                             static_cast<mx_uint>(numArgs),
1882                             reinterpret_cast<NDArrayHandle *>(inArgs),
1883                             reinterpret_cast<NDArrayHandle *>(argGradStore),
1884                             reinterpret_cast<mx_uint *>(gradReqType),
1885                             static_cast<mx_uint>(auxStatesLen),
1886                             reinterpret_cast<NDArrayHandle *>(auxStates),
1887                             &out);
1888   env->ReleaseIntArrayElements(jctxMapDevIDs, mapDevIDs, 0);
1889   env->ReleaseIntArrayElements(jctxMapDevTypes, mapDevTypes, 0);
1890   env->ReleaseLongArrayElements(jargsGradHandle, argGradStore, 0);
1891   env->ReleaseLongArrayElements(jargsHandle, inArgs, 0);
1892   env->ReleaseIntArrayElements(jreqsArray, gradReqType, 0);
1893   env->ReleaseLongArrayElements(jauxArgsHandle, auxStates, 0);
1894   for (int i = 0; i < numCtx; i++) {
1895     jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i);
1896     env->ReleaseStringUTFChars(jkey, mapKeys[i]);
1897     env->DeleteLocalRef(jkey);
1898   }
1899   delete[] mapKeys;
1900 
1901   SetLongField(env, jexecOut, reinterpret_cast<jlong>(out));
1902   return ret;
1903 }
1904 
Java_org_apache_mxnet_LibInfo_mxExecutorBindEX(JNIEnv * env,jobject obj,jlong symbolPtr,jint deviceTypeId,jint deviceID,jint numCtx,jobjectArray jctxMapKeys,jintArray jctxMapDevTypes,jintArray jctxMapDevIDs,jint numArgs,jlongArray jargsHandle,jlongArray jargsGradHandle,jintArray jreqsArray,jlongArray jauxArgsHandle,jlong jsharedExec,jobject jexecOut)1905 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindEX
1906   (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
1907     jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
1908     jlongArray jargsHandle, jlongArray jargsGradHandle, jintArray jreqsArray,
1909     jlongArray jauxArgsHandle, jlong jsharedExec, jobject jexecOut) {
1910   ExecutorHandle out;
1911   int auxStatesLen = env->GetArrayLength(jauxArgsHandle);
1912   ExecutorHandle sharedExec = nullptr;
1913   if ((int32_t)jsharedExec != 0) sharedExec = reinterpret_cast<ExecutorHandle>(jsharedExec);
1914 
1915   const char **mapKeys = new const char *[numCtx];
1916   for (int i = 0; i < numCtx; i++) {
1917     jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jctxMapKeys, i));
1918     const char *key = env->GetStringUTFChars(jkey, 0);
1919     mapKeys[i] = key;
1920     env->DeleteLocalRef(jkey);
1921   }
1922   jlong *auxStates = env->GetLongArrayElements(jauxArgsHandle, NULL);
1923   jint *gradReqType = env->GetIntArrayElements(jreqsArray, NULL);
1924   jlong *inArgs = env->GetLongArrayElements(jargsHandle, NULL);
1925   jlong *argGradStore = env->GetLongArrayElements(jargsGradHandle, NULL);
1926   jint *mapDevTypes = env->GetIntArrayElements(jctxMapDevTypes, NULL);
1927   jint *mapDevIDs = env->GetIntArrayElements(jctxMapDevIDs, NULL);
1928   int ret = MXExecutorBindEX(reinterpret_cast<SymbolHandle>(symbolPtr),
1929                             deviceTypeId,
1930                             deviceID,
1931                             static_cast<mx_uint>(numCtx),
1932                             mapKeys,
1933                             mapDevTypes,
1934                             mapDevIDs,
1935                             static_cast<mx_uint>(numArgs),
1936                             reinterpret_cast<NDArrayHandle *>(inArgs),
1937                             reinterpret_cast<NDArrayHandle *>(argGradStore),
1938                             reinterpret_cast<mx_uint *>(gradReqType),
1939                             static_cast<mx_uint>(auxStatesLen),
1940                             reinterpret_cast<NDArrayHandle *>(auxStates),
1941                             sharedExec,
1942                             &out);
1943   env->ReleaseIntArrayElements(jctxMapDevIDs, mapDevIDs, 0);
1944   env->ReleaseIntArrayElements(jctxMapDevTypes, mapDevTypes, 0);
1945   env->ReleaseLongArrayElements(jargsGradHandle, argGradStore, 0);
1946   env->ReleaseLongArrayElements(jargsHandle, inArgs, 0);
1947   env->ReleaseIntArrayElements(jreqsArray, gradReqType, 0);
1948   env->ReleaseLongArrayElements(jauxArgsHandle, auxStates, 0);
1949   for (int i = 0; i < numCtx; i++) {
1950     jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i);
1951     env->ReleaseStringUTFChars(jkey, mapKeys[i]);
1952     env->DeleteLocalRef(jkey);
1953   }
1954   delete[] mapKeys;
1955 
1956   SetLongField(env, jexecOut, reinterpret_cast<jlong>(out));
1957   return ret;
1958 }
1959 
Java_org_apache_mxnet_LibInfo_mxRandomSeed(JNIEnv * env,jobject obj,jint seed)1960 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRandomSeed
1961   (JNIEnv *env, jobject obj, jint seed) {
1962   return MXRandomSeed(seed);
1963 }
1964 
Java_org_apache_mxnet_LibInfo_mxNotifyShutdown(JNIEnv * env,jobject obj)1965 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNotifyShutdown
1966   (JNIEnv *env, jobject obj) {
1967   return MXNotifyShutdown();
1968 }
1969 
Java_org_apache_mxnet_LibInfo_mxRecordIOWriterCreate(JNIEnv * env,jobject obj,jstring juri,jobject handle)1970 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterCreate
1971   (JNIEnv *env, jobject obj, jstring juri, jobject handle) {
1972   RecordIOHandle out;
1973   const char *uri = env->GetStringUTFChars(juri, 0);
1974   int ret = MXRecordIOWriterCreate(uri, &out);
1975   env->ReleaseStringUTFChars(juri, uri);
1976   SetLongField(env, handle, reinterpret_cast<jlong>(out));
1977   return ret;
1978 }
1979 
Java_org_apache_mxnet_LibInfo_mxRecordIOReaderCreate(JNIEnv * env,jobject obj,jstring juri,jobject handle)1980 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderCreate
1981   (JNIEnv *env, jobject obj, jstring juri, jobject handle) {
1982   RecordIOHandle out;
1983   const char *uri = env->GetStringUTFChars(juri, 0);
1984   int ret = MXRecordIOReaderCreate(uri, &out);
1985   env->ReleaseStringUTFChars(juri, uri);
1986   SetLongField(env, handle, reinterpret_cast<jlong>(out));
1987   return ret;
1988 }
1989 
Java_org_apache_mxnet_LibInfo_mxRecordIOWriterFree(JNIEnv * env,jobject obj,jlong handle)1990 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterFree
1991   (JNIEnv *env, jobject obj, jlong handle) {
1992   RecordIOHandle recordIOHandle = reinterpret_cast<RecordIOHandle>(handle);
1993   int ret = MXRecordIOWriterFree(recordIOHandle);
1994   return ret;
1995 }
1996 
Java_org_apache_mxnet_LibInfo_mxRecordIOReaderFree(JNIEnv * env,jobject obj,jlong handle)1997 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderFree
1998   (JNIEnv *env, jobject obj, jlong handle) {
1999   RecordIOHandle recordIOHandle = reinterpret_cast<RecordIOHandle>(handle);
2000   int ret = MXRecordIOReaderFree(&recordIOHandle);
2001   return ret;
2002 }
2003 
Java_org_apache_mxnet_LibInfo_mxRecordIOWriterWriteRecord(JNIEnv * env,jobject obj,jlong handle,jstring jbuf,jint size)2004 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterWriteRecord
2005   (JNIEnv *env, jobject obj, jlong handle, jstring jbuf, jint size) {
2006   const char *buf = env->GetStringUTFChars(jbuf, 0);
2007   RecordIOHandle *recordIOHandle = reinterpret_cast<RecordIOHandle *>(handle);
2008   int ret = MXRecordIOWriterWriteRecord(recordIOHandle, buf, size);
2009   env->ReleaseStringUTFChars(jbuf, buf);
2010   return ret;
2011 }
2012 
Java_org_apache_mxnet_LibInfo_mxRecordIOReaderReadRecord(JNIEnv * env,jobject obj,jlong handle,jobject buf)2013 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderReadRecord
2014   (JNIEnv *env, jobject obj, jlong handle, jobject buf) {
2015   RecordIOHandle *recordIOHandle = reinterpret_cast<RecordIOHandle *>(handle);
2016   size_t size;
2017   char const  *out;
2018   int ret = MXRecordIOReaderReadRecord(recordIOHandle, &out, &size);
2019   SetStringField(env, buf, out);
2020   return ret;
2021 }
2022 
Java_org_apache_mxnet_LibInfo_mxRecordIOWriterTell(JNIEnv * env,jobject obj,jlong handle,jobject jpos)2023 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterTell
2024   (JNIEnv *env, jobject obj, jlong handle, jobject jpos) {
2025   RecordIOHandle *recordIOHandle = reinterpret_cast<RecordIOHandle *>(handle);
2026   size_t pos;
2027   int ret = MXRecordIOWriterTell(recordIOHandle, &pos);
2028   SetIntField(env, jpos, pos);
2029   return ret;
2030 }
2031 
Java_org_apache_mxnet_LibInfo_mxRecordIOReaderSeek(JNIEnv * env,jobject obj,jlong handle,jint pos)2032 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderSeek
2033   (JNIEnv *env, jobject obj, jlong handle, jint pos) {
2034   RecordIOHandle *recordIOHandle = reinterpret_cast<RecordIOHandle *>(handle);
2035   int ret = MXRecordIOReaderSeek(recordIOHandle, pos);
2036   return ret;
2037 }
2038 
Java_org_apache_mxnet_LibInfo_mxRtcCreate(JNIEnv * env,jobject obj,jstring jname,jobjectArray jinputNames,jobjectArray joutputNames,jlongArray jinputs,jlongArray joutputs,jstring jkernel,jobject jhandle)2039 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcCreate
2040   (JNIEnv *env, jobject obj, jstring jname, jobjectArray jinputNames,
2041     jobjectArray joutputNames, jlongArray jinputs, jlongArray joutputs,
2042     jstring jkernel, jobject jhandle) {
2043   RtcHandle out;
2044   char *name = const_cast<char *>(env->GetStringUTFChars(jname, 0));
2045   int num_input = env->GetArrayLength(jinputNames);
2046   char **inputNames = new char *[num_input];
2047   for (int i = 0; i < num_input; i++) {
2048     jstring jinname = reinterpret_cast<jstring>(env->GetObjectArrayElement(jinputNames, i));
2049     char *inname = const_cast<char *>(env->GetStringUTFChars(jinname, 0));
2050     inputNames[i] = inname;
2051     env->DeleteLocalRef(jinname);
2052   }
2053   int num_output = env->GetArrayLength(joutputNames);
2054   char **outputNames = new char *[num_output];
2055   for (int i = 0; i < num_output; i++) {
2056     jstring joutname = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputNames, i));
2057     char *outname = const_cast<char *>(env->GetStringUTFChars(joutname, 0));
2058     outputNames[i] = outname;
2059     env->DeleteLocalRef(joutname);
2060   }
2061   jlong *inputs = env->GetLongArrayElements(jinputs, NULL);
2062   jlong *outputs = env->GetLongArrayElements(joutputs, NULL);
2063   char *kernel = const_cast<char *>(env->GetStringUTFChars(jkernel, 0));
2064 
2065   int ret = MXRtcCreate(name,
2066                         static_cast<mx_uint>(num_input),
2067                         static_cast<mx_uint>(num_output),
2068                         inputNames,
2069                         outputNames,
2070                         reinterpret_cast<NDArrayHandle *>(inputs),
2071                         reinterpret_cast<NDArrayHandle *>(outputs),
2072                         kernel,
2073                         &out);
2074 
2075   // release allocated memory
2076   env->ReleaseStringUTFChars(jname, name);
2077   env->ReleaseStringUTFChars(jkernel, kernel);
2078   env->ReleaseLongArrayElements(jinputs, inputs, 0);
2079   env->ReleaseLongArrayElements(joutputs, outputs, 0);
2080   for (int i = 0; i < num_input; i++) {
2081     jstring jinname = reinterpret_cast<jstring>(env->GetObjectArrayElement(jinputNames, i));
2082     env->ReleaseStringUTFChars(jinname, inputNames[i]);
2083     env->DeleteLocalRef(jinname);
2084   }
2085   delete[] inputNames;
2086   for (int i = 0; i < num_output; i++) {
2087     jstring joutname = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputNames, i));
2088     env->ReleaseStringUTFChars(joutname, outputNames[i]);
2089     env->DeleteLocalRef(joutname);
2090   }
2091   delete[] outputNames;
2092 
2093   SetLongField(env, jhandle, reinterpret_cast<jlong>(out));
2094   return ret;
2095 }
2096 
Java_org_apache_mxnet_LibInfo_mxRtcPush(JNIEnv * env,jobject obj,jlong jhandle,jlongArray jinputs,jlongArray joutputs,jint gridDimX,jint gridDimY,jint gridDimZ,jint blockDimX,jint blockDimY,jint blockDimZ)2097 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcPush
2098   (JNIEnv *env, jobject obj, jlong jhandle, jlongArray jinputs,
2099     jlongArray joutputs, jint gridDimX, jint gridDimY, jint gridDimZ,
2100     jint blockDimX, jint blockDimY, jint blockDimZ) {
2101 
2102   RtcHandle handle = reinterpret_cast<RtcHandle>(jhandle);
2103   jlong *inputs = env->GetLongArrayElements(jinputs, NULL);
2104   jlong *outputs = env->GetLongArrayElements(joutputs, NULL);
2105   int num_input = env->GetArrayLength(jinputs);
2106   int num_output = env->GetArrayLength(joutputs);
2107 
2108   int ret = MXRtcPush(handle,
2109                       static_cast<mx_uint>(num_input),
2110                       static_cast<mx_uint>(num_output),
2111                       reinterpret_cast<NDArrayHandle *>(inputs),
2112                       reinterpret_cast<NDArrayHandle *>(outputs),
2113                       static_cast<mx_uint>(gridDimX),
2114                       static_cast<mx_uint>(gridDimY),
2115                       static_cast<mx_uint>(gridDimZ),
2116                       static_cast<mx_uint>(blockDimX),
2117                       static_cast<mx_uint>(blockDimY),
2118                       static_cast<mx_uint>(blockDimZ));
2119 
2120   // release allocated memory
2121   env->ReleaseLongArrayElements(jinputs, inputs, 0);
2122   env->ReleaseLongArrayElements(joutputs, outputs, 0);
2123 
2124   return ret;
2125 }
2126 
Java_org_apache_mxnet_LibInfo_mxRtcFree(JNIEnv * env,jobject obj,jlong jhandle)2127 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcFree
2128   (JNIEnv *env, jobject obj, jlong jhandle) {
2129   RtcHandle handle = reinterpret_cast<RtcHandle>(jhandle);
2130   int ret = MXRtcFree(handle);
2131   return ret;
2132 }
2133 
// Registry of user-defined CustomOpProp objects, keyed by the name they were
// registered under. mxCustomOpRegister stores JNI global references here, so
// the objects stay reachable from later native callbacks.
std::unordered_map<std::string, jobject> globalOpPropMap;
// Registry of user-defined CustomOp objects, keyed by name.
std::unordered_map<std::string, jobject> globalOpMap;
// Mutexes for thread safety when inserting elements into or erasing elements
// from the maps above: mutex_opprop guards globalOpPropMap (it is held around
// the insert in mxCustomOpRegister); mutex_op presumably guards globalOpMap.
// NOTE(review): the find()/at() lookups inside the mxCustomOpRegister callbacks
// do not take mutex_opprop. Confirm that registration cannot run concurrently
// with those lookups.
std::mutex mutex_opprop;
std::mutex mutex_op;
2142 
2143 // Registers a custom operator when called
Java_org_apache_mxnet_LibInfo_mxCustomOpRegister(JNIEnv * env,jobject obj,jstring jregName,jobject jopProp)2144 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister
2145   (JNIEnv *env, jobject obj, jstring jregName, jobject jopProp) {
2146   const char *regName = env->GetStringUTFChars(jregName, 0);
2147   std::string key(regName);
2148 
2149   std::unique_lock<std::mutex> lock(mutex_opprop);
2150   globalOpPropMap.insert({ key, env->NewGlobalRef(jopProp) });
2151   lock.unlock();
2152 
2153   // lambda function to initialize the operator and create all callbacks
2154   auto creatorLambda = [](const char *opType, const int numKwargs,
2155     const char  **keys, const char **values, MXCallbackList *ret) {
2156     int success = true;
2157 
2158     std::string opPropKey(opType);
2159     if (globalOpPropMap.find(opPropKey) == globalOpPropMap.end()) {
2160       LOG(WARNING) << "CustomOpProp: " << opPropKey << " not found";
2161       success = false;
2162     } else {
2163       JNIEnv *env;
2164       _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2165       jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(opPropKey));
2166       jmethodID midInit = env->GetMethodID(opPropClass,
2167         "init", "([Ljava/lang/String;[Ljava/lang/String;)V");
2168       if (NULL == midInit) {
2169         LOG(WARNING) << "could not find CustomOpProp method init.";
2170         success = false;
2171       } else {
2172         // call init and set CustomOpProp.kwargs
2173         jclass strCls = env->FindClass("Ljava/lang/String;");
2174         jobjectArray keysArr = env->NewObjectArray(numKwargs, strCls, NULL);
2175         jobjectArray valuesArr = env->NewObjectArray(numKwargs, strCls, NULL);
2176         for (int i = 0; i < numKwargs; ++i) {
2177           jstring keyStr = env->NewStringUTF(keys[i]);
2178           jstring valueStr = env->NewStringUTF(values[i]);
2179           env->SetObjectArrayElement(keysArr, i, keyStr);
2180           env->SetObjectArrayElement(valuesArr, i, valueStr);
2181           env->DeleteLocalRef(keyStr);
2182           env->DeleteLocalRef(valueStr);
2183         }
2184         env->CallVoidMethod(globalOpPropMap.at(opPropKey), midInit, keysArr, valuesArr);
2185         env->DeleteLocalRef(keysArr);
2186         env->DeleteLocalRef(valuesArr);
2187       }
2188       _jvm->DetachCurrentThread();
2189     }
2190 
2191     // list_arguments callback
2192     auto opPropListArgument = [](char ***args, void *state) {
2193       int success = true;
2194       std::string key(reinterpret_cast<char *>(state));
2195       if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2196         LOG(WARNING) << "CustomOpProp: " << key << " not found";
2197         success = false;
2198       } else {
2199         JNIEnv *env;
2200         _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2201         jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2202         jmethodID midListArguments = env->GetMethodID(
2203           opPropClass, "listArguments", "()[Ljava/lang/String;");
2204         if (NULL == midListArguments) {
2205           LOG(WARNING) << "could not find opProp method listArguments.";
2206           success = false;
2207         } else {
2208           jobjectArray jargs =(jobjectArray)(env->CallObjectMethod(
2209             globalOpPropMap.at(key), midListArguments));
2210           int len = env->GetArrayLength(jargs);
2211           *args = new char *[len+1];
2212           for (int i = 0; i < len; ++i) {
2213             jstring jarg = reinterpret_cast<jstring>(env->GetObjectArrayElement(jargs, i));
2214             const char *arg = env->GetStringUTFChars(jarg, 0);
2215             (*args)[i] = const_cast<char *>(arg);
2216             env->DeleteLocalRef(jarg);
2217           }
2218           (*args)[len] = NULL;
2219         }
2220         _jvm->DetachCurrentThread();
2221       }
2222       return success;
2223     };
2224 
2225     // list_outputs callback
2226     auto opPropListOutputs = [](char ***outputs, void *state) {
2227       int success = true;
2228       std::string key(reinterpret_cast<char *>(state));
2229       if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2230         LOG(WARNING) << "CustomOpProp: " << key << " not found";
2231         success = false;
2232       } else {
2233         JNIEnv *env;
2234         _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2235         jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2236         jmethodID midListOutputs = env->GetMethodID(
2237           opPropClass, "listOutputs", "()[Ljava/lang/String;");
2238         if (NULL == midListOutputs) {
2239           LOG(WARNING) << "could not find opProp method listOutputs.";
2240           success = false;
2241         } else {
2242           jobjectArray joutputs = (jobjectArray)(env->CallObjectMethod(
2243             globalOpPropMap.at(key), midListOutputs));
2244           int len = env->GetArrayLength(joutputs);
2245           *outputs = new char *[len + 1];
2246           for (int i = 0; i < len; ++i) {
2247             jstring joutput = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputs, i));
2248             const char *output = env->GetStringUTFChars(joutput, 0);
2249             (*outputs)[i] = const_cast<char *>(output);
2250             env->DeleteLocalRef(joutput);
2251           }
2252           (*outputs)[len] = NULL;
2253         }
2254         _jvm->DetachCurrentThread();
2255       }
2256       return success;
2257     };
2258 
2259     // list_auxiliary_states callback
2260     auto opPropListAuxStates = [](char ***auxs, void *state) {
2261       int success = true;
2262       std::string key(reinterpret_cast<char *>(state));
2263       if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2264         LOG(WARNING) << "CustomOpProp: " << key << " not found";
2265         success = false;
2266       } else {
2267         JNIEnv *env;
2268         _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2269         jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2270         jmethodID midListAuxStates = env->GetMethodID(
2271           opPropClass, "listAuxiliaryStates", "()[Ljava/lang/String;");
2272         if (NULL == midListAuxStates) {
2273           LOG(WARNING) << "could not find opProp method listAuxiliaryStates.";
2274           success = false;
2275         } else {
2276           auto obj = env->CallObjectMethod(globalOpPropMap.at(key), midListAuxStates);
2277           if (obj != NULL) {
2278             jobjectArray jauxs = (jobjectArray)obj;
2279             int len = env->GetArrayLength(jauxs);
2280             *auxs = new char *[len+1];
2281             for (int i = 0; i < len; ++i) {
2282               jstring jaux = reinterpret_cast<jstring>(env->GetObjectArrayElement(jauxs, i));
2283               const char *aux = env->GetStringUTFChars(jaux, 0);
2284               (*auxs)[i] = const_cast<char *>(aux);
2285               env->DeleteLocalRef(jaux);
2286             }
2287             (*auxs)[len] = NULL;
2288           } else {
2289             (*auxs) = new char *[1];
2290             (*auxs)[0] = NULL;
2291           }
2292         }
2293         _jvm->DetachCurrentThread();
2294       }
2295       return success;
2296     };
2297 
2298     // declare_backward_dependency callback
2299     auto opPropDeclareBkDep = [](const int *outGrad, const int *inData,
2300       const int *outData, int *numDeps, int **rdeps, void *state) {
2301       int success = true;
2302       std::string key(reinterpret_cast<char *>(state));
2303       if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2304         LOG(WARNING) << "CustomOpProp: " << key << " not found";
2305         success = false;
2306       } else {
2307         JNIEnv *env;
2308         _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2309         jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2310         jmethodID midDeclareBkDep = env->GetMethodID(
2311           opPropClass, "declareBackwardDependency", "([I[I[I)[I");
2312         if (NULL == midDeclareBkDep) {
2313           LOG(WARNING) << "could not find opProp method declareBackwardDependency.";
2314           success = false;
2315         } else {
2316           jmethodID midListOutputs = env->GetMethodID(
2317             opPropClass, "listOutputs", "()[Ljava/lang/String;");
2318           jobjectArray joutputs = (jobjectArray)(env->CallObjectMethod(
2319             globalOpPropMap.at(key), midListOutputs));
2320           int outLen = env->GetArrayLength(joutputs);
2321           jmethodID midListArguments = env->GetMethodID(
2322             opPropClass, "listArguments", "()[Ljava/lang/String;");
2323           jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
2324             globalOpPropMap.at(key), midListArguments));
2325           int intLen = env->GetArrayLength(jargs);
2326 
2327           jintArray outGradArr = env->NewIntArray(outLen);
2328           env->SetIntArrayRegion(outGradArr, (jsize)0, (jsize)outLen, outGrad);
2329           jintArray inDataArr = env->NewIntArray(intLen);
2330           env->SetIntArrayRegion(inDataArr, (jsize)0, (jsize)intLen, inData);
2331           jintArray outDataArr = env->NewIntArray(outLen);
2332           env->SetIntArrayRegion(outDataArr, (jsize)0, (jsize)outLen, outData);
2333 
2334           auto obj = env->CallObjectMethod(globalOpPropMap.at(key), midDeclareBkDep,
2335                                                    outGradArr,
2336                                                    inDataArr,
2337                                                    outDataArr);
2338           jintArray jrdeps = (jintArray)obj;
2339           jint *rdepsArr = env->GetIntArrayElements(jrdeps, NULL);
2340 
2341           *numDeps = env->GetArrayLength(jrdeps);
2342           *rdeps = new int[(* numDeps)];
2343           for (int i = 0 ; i < (*numDeps); ++i) {
2344             (*rdeps)[i] = rdepsArr[i];
2345           }
2346           env->DeleteLocalRef(outGradArr);
2347           env->DeleteLocalRef(inDataArr);
2348           env->DeleteLocalRef(outDataArr);
2349           env->ReleaseIntArrayElements(jrdeps, rdepsArr, 0);
2350         }
2351         _jvm->DetachCurrentThread();
2352       }
2353       return success;
2354     };
2355 
2356     // infer_shape callback
2357     auto opPropInferShape = [](int numInput, int *ndims,
2358       unsigned **shapes, void *state) {
2359       int success = true;
2360       std::string key(reinterpret_cast<char *>(state));
2361       if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2362         LOG(WARNING) << "CustomOpProp: " << key << " not found";
2363         success = false;
2364       } else {
2365         JNIEnv *env;
2366         _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2367         jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2368         jmethodID midInferShape = env->GetMethodID(opPropClass, "inferShapeEntry", "(I[[I)[[I");
2369         if (NULL == midInferShape) {
2370           LOG(WARNING) << "could not find opProp method inferShapeEntry.";
2371           success = false;
2372         } else {
2373           jmethodID midListArguments = env->GetMethodID(
2374             opPropClass, "listArguments", "()[Ljava/lang/String;");
2375           jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
2376             globalOpPropMap.at(key), midListArguments));
2377           int intLen = env->GetArrayLength(jargs);
2378           jintArray *ts = new jintArray[intLen];
2379           auto tmp = env->NewIntArray(1);
2380           jclass arrayClass = env->GetObjectClass(tmp);
2381           env->DeleteLocalRef(tmp);
2382           jobjectArray tensorShapes = env->NewObjectArray(intLen, arrayClass, NULL);
2383           for (int i = 0; i < intLen; ++i) {
2384             ts[i] = env->NewIntArray(ndims[i]);
2385             env->SetIntArrayRegion(
2386               ts[i], (jsize)0, (jsize)ndims[i], reinterpret_cast<int *>(shapes[i]));
2387             env->SetObjectArrayElement(tensorShapes, i, (jobject)(ts[i]));
2388           }
2389           jobjectArray ret = (jobjectArray)(env->CallObjectMethod(
2390             globalOpPropMap.at(key), midInferShape,
2391             numInput,
2392             tensorShapes));
2393           for (int i = 0; i < numInput; ++i) {
2394             jintArray jarr = reinterpret_cast<jintArray>(env->GetObjectArrayElement(ret, i));
2395             int len = env->GetArrayLength(jarr);
2396             jint *arr = env->GetIntArrayElements(jarr, NULL);
2397             ndims[i] = len;
2398             shapes[i] = new unsigned[len];
2399             for (int j = 0; j < len; ++j) shapes[i][j] = (unsigned)(arr[j]);
2400             env->DeleteLocalRef(jarr);
2401           }
2402           for (int i = 0; i < intLen; ++i) {
2403             env->DeleteLocalRef(ts[i]);
2404           }
2405           delete[] ts;
2406         }
2407         _jvm->DetachCurrentThread();
2408       }
2409       return success;
2410     };
2411 
2412     // infer_type callback
2413     auto opPropInferType = [](int numInput, int* types, void* state) {
2414       int success = true;
2415       std::string key(reinterpret_cast<char *>(state));
2416       if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2417         LOG(WARNING) << "CustomOpProp: " << key << " not found";
2418         success = false;
2419       } else {
2420         JNIEnv *env;
2421         _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2422         jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2423         jmethodID midInferType = env->GetMethodID(opPropClass, "inferTypeEntry", "(I[I)[I");
2424         if (NULL == midInferType) {
2425           LOG(WARNING) << "could not find opProp method inferTypeEntry.";
2426           success = false;
2427         } else {
2428           jmethodID midListArguments = env->GetMethodID(
2429             opPropClass, "listArguments", "()[Ljava/lang/String;");
2430           jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
2431             globalOpPropMap.at(key), midListArguments));
2432 
2433           int intLen = env->GetArrayLength(jargs);
2434           jintArray ts = env->NewIntArray(intLen);
2435           int *tmp = new int[intLen];
2436           for (int i = 0; i < intLen; ++i) tmp[i] = types[i];
2437           env->SetIntArrayRegion(ts, (jsize)0, (jsize)intLen, tmp);
2438 
2439           jintArray ret = (jintArray)(env->CallObjectMethod(
2440             globalOpPropMap.at(key), midInferType,
2441             numInput,
2442             ts));
2443           jint *arr = env->GetIntArrayElements(ret, NULL);
2444           for (int i = 0; i < numInput; ++i) {
2445             types[i] = static_cast<int>(arr[i]);
2446           }
2447 
2448           delete[] tmp;
2449           env->ReleaseIntArrayElements(ret, arr, 0);
2450           env->DeleteLocalRef(ret);
2451           env->DeleteLocalRef(ts);
2452         }
2453         _jvm->DetachCurrentThread();
2454       }
2455       return success;
2456     };
2457 
2458     // create_operator callback
2459     auto opPropCreateOp = [](const char *ctx, int numInputs,
2460       unsigned **shapes, int *ndims, int *dtypes, MXCallbackList *ret, void *state) {
2461       int success = true;
2462       std::string key(reinterpret_cast<char *>(state));
2463       if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2464         LOG(WARNING) << "CustomOpProp: " << key << " not found";
2465         success = false;
2466       } else {
2467         JNIEnv *env;
2468         _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2469         jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2470         jmethodID midCreateOp = env->GetMethodID(
2471           opPropClass, "createOperator", "(Ljava/lang/String;[[I[I)Lorg/apache/mxnet/CustomOp;");
2472         if (NULL == midCreateOp) {
2473           LOG(WARNING) << "could not find opProp method createOperator.";
2474           success = false;
2475         } else {
2476           jstring jctx = env->NewStringUTF(ctx);
2477           jintArray *ts = new jintArray[numInputs];
2478           auto tmp = env->NewIntArray(1);
2479           jclass arrayClass = env->GetObjectClass(tmp);
2480           env->DeleteLocalRef(tmp);
2481           jobjectArray inputShapes = env->NewObjectArray(numInputs, arrayClass, NULL);
2482           for (int i = 0; i < numInputs; ++i) {
2483             ts[i] = env->NewIntArray(ndims[i]);
2484             env->SetIntArrayRegion(
2485               ts[i], (jsize)0, (jsize)ndims[i], reinterpret_cast<int *>(shapes[i]));
2486             env->SetObjectArrayElement(inputShapes, i, (jobject)(ts[i]));
2487           }
2488           jintArray jdtypes = env->NewIntArray(numInputs);
2489           env->SetIntArrayRegion(jdtypes, (jsize)0, (jsize)numInputs, dtypes);
2490           // get operator
2491           jobject jOp = env->CallObjectMethod(globalOpPropMap.at(key), midCreateOp,
2492                                         jctx,
2493                                         inputShapes,
2494                                         jdtypes);
2495           env->DeleteLocalRef(jctx);
2496           for (int i = 0; i < numInputs; ++i) {
2497             env->DeleteLocalRef(ts[i]);
2498           }
2499           delete[] ts;
2500 
2501           std::unique_lock<std::mutex> lock(mutex_op);
2502           globalOpMap.insert({ key, env->NewGlobalRef(jOp) });
2503           lock.unlock();
2504 
2505           _jvm->DetachCurrentThread();
2506 
2507           // forward callback
2508           auto forwardEntry = [](int size, void **ptrs, int *tags,
2509             const int *reqs, const int isTrain, void *state) {
2510             std::string key(reinterpret_cast<char *>(state));
2511             int success = true;
2512             if (globalOpMap.find(key) == globalOpMap.end()) {
2513               LOG(WARNING) << "op: " << key << " not found";
2514               success = false;
2515             } else {
2516               JNIEnv *env;
2517               _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2518               jclass opClass =  env->GetObjectClass(globalOpMap.at(key));
2519               jmethodID midForward = env->GetMethodID(opClass, "forwardEntry", "(I[J[I[IZ)Z");
2520               if (NULL == midForward) {
2521                 LOG(WARNING) << "could not find op method forwardEntry.";
2522                 success = false;
2523               } else {
2524                 jintArray tagsArr = env->NewIntArray(size);
2525                 env->SetIntArrayRegion(tagsArr, (jsize)0, (jsize)size, tags);
2526                 int reqSize = 0;
2527                 for (int i = 0; i < size; ++i) {
2528                   if (tags[i] == 1) reqSize++;
2529                 }
2530                 jintArray reqsArr = env->NewIntArray(reqSize);
2531                 env->SetIntArrayRegion(reqsArr, (jsize)0, (jsize)reqSize, reqs);
2532                 jlongArray ptrsArr = env->NewLongArray(size);
2533                 env->SetLongArrayRegion(
2534                   ptrsArr, (jsize)0, (jsize)size, reinterpret_cast<jlong*>(ptrs));
2535 #if MXNET_USE_CUDA
2536                 mxnet::NDArray* tmp = reinterpret_cast<mxnet::NDArray*>(ptrs[0]);
2537                 if (tmp->ctx().dev_type == mxnet::Context::kGPU
2538                   || tmp->ctx().dev_type == mxnet::Context::kCPUPinned) {
2539                   CUDA_CALL(cudaSetDevice(tmp->ctx().dev_id));
2540                 }
2541 #endif
2542                 bool is_train =  true;
2543                 if (isTrain == 0) is_train = false;
2544                 success = env->CallBooleanMethod(globalOpMap.at(key), midForward,
2545                                                        size,
2546                                                        ptrsArr,
2547                                                        tagsArr,
2548                                                        reqsArr,
2549                                                        is_train);
2550                 env->DeleteLocalRef(tagsArr);
2551                 env->DeleteLocalRef(reqsArr);
2552                 env->DeleteLocalRef(ptrsArr);
2553               }
2554               _jvm->DetachCurrentThread();
2555             }
2556             return success;
2557           };
2558 
2559           // backward callback
2560           auto backwardEntry = [](int size, void **ptrs, int *tags,
2561             const int *reqs, const int isTrain, void *state) {
2562             std::string key(reinterpret_cast<char *>(state));
2563             int success = true;
2564             if (globalOpMap.find(key) == globalOpMap.end()) {
2565               LOG(WARNING) << "op: " << key << " not found";
2566               success = false;
2567             } else {
2568               JNIEnv *env;
2569               _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2570               jclass opClass = env->GetObjectClass(globalOpMap.at(key));
2571               jmethodID midBackward = env->GetMethodID(opClass, "backwardEntry", "(I[J[I[IZ)Z");
2572               if (NULL == midBackward) {
2573                 LOG(WARNING) << "could not find op method backwardEntry.";
2574                 success = false;
2575               } else {
2576                 jintArray tagsArr = env->NewIntArray(size);
2577                 env->SetIntArrayRegion(tagsArr, (jsize)0, (jsize)size, tags);
2578 
2579                 int reqSize = 0;
2580                 for (int i = 0; i < size; ++i) {
2581                   if (tags[i] == 2) reqSize++;
2582                 }
2583                 jintArray reqsArr = env->NewIntArray(reqSize);
2584                 env->SetIntArrayRegion(reqsArr, (jsize)0, (jsize)reqSize, reqs);
2585                 jlongArray ptrsArr = env->NewLongArray(size);
2586                 env->SetLongArrayRegion(
2587                   ptrsArr, (jsize)0, (jsize)size, reinterpret_cast<jlong*>(ptrs));
2588                 bool is_train =  true;
2589                 if (isTrain == 0) is_train = false;
2590                 success = env->CallBooleanMethod(globalOpMap.at(key), midBackward,
2591                                                        size,
2592                                                        ptrsArr,
2593                                                        tagsArr,
2594                                                        reqsArr,
2595                                                        is_train);
2596                 env->DeleteLocalRef(tagsArr);
2597                 env->DeleteLocalRef(reqsArr);
2598                 env->DeleteLocalRef(ptrsArr);
2599               }
2600               _jvm->DetachCurrentThread();
2601             }
2602             return success;
2603           };
2604 
2605           // del callback
2606           auto delEntry = [](void *state) {
2607             std::string key(reinterpret_cast<char *>(state));
2608             int success = true;
2609             std::unique_lock<std::mutex> lock(mutex_op);
2610             if (globalOpMap.find(key) == globalOpMap.end()) {
2611               LOG(WARNING) << "op: " << key << " not found";
2612               success = false;
2613             } else {
2614               JNIEnv *env;
2615               _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2616               env->DeleteGlobalRef(globalOpMap.at(key));
2617               _jvm->DetachCurrentThread();
2618               for (auto it = globalOpMap.begin(); it != globalOpMap.end(); ) {
2619                 if (it->first == key) {
2620                   it = globalOpMap.erase(it);
2621                 } else {
2622                   ++it;
2623                 }
2624               }
2625             }
2626             lock.unlock();
2627             return success;
2628           };
2629 
2630           // TODO(eric): Memory leak here. Refactor later and delete in delEntry
2631           ret->num_callbacks = 3;
2632           ret->callbacks = new MXGenericCallback[ret->num_callbacks];
2633           ret->callbacks[kCustomOpDelete] =
2634             reinterpret_cast<int(*)(void)>(static_cast<int(*)(void*)>(delEntry));
2635           ret->callbacks[kCustomOpForward] =
2636             reinterpret_cast<int(*)(void)>(
2637               static_cast<int(*)(int, void**, int*, const int*, const int, void*)>(
2638                 forwardEntry));
2639           ret->callbacks[kCustomOpBackward] =
2640             reinterpret_cast<int(*)(void)>(
2641               static_cast<int(*)(int, void**, int*, const int*, const int, void*)>(
2642                 backwardEntry));
2643           ret->contexts = new void*[ret->num_callbacks];
2644           ret->contexts[kCustomOpDelete] = state;
2645           ret->contexts[kCustomOpForward] = state;
2646           ret->contexts[kCustomOpBackward] = state;
2647         }
2648       }
2649       return success;
2650     };
2651 
2652     // del callback
2653     auto opPropDel = [](void *state) {
2654       /*
2655        * This method seems to be called by the engine to clean up after multiple calls were made
2656        * to the creator lambda. The current creator function isn't allocating a new object but is
2657        * instead reinitializing the object which was created when register was called. This means
2658        * that there doesn't seem to be anything to clean up here (previous efforts were actually
2659        * deregistering the operator).
2660       */
2661       return 1;
2662     };
2663 
2664     // TODO(eric): Memory leak. Missing infertype.
2665     ret->num_callbacks = 8;
2666     ret->callbacks = new MXGenericCallback[ret->num_callbacks];
2667     ret->callbacks[kCustomOpPropDelete] =
2668       reinterpret_cast<int(*)(void)>(
2669         static_cast<int(*)(void*)>(opPropDel));
2670     ret->callbacks[kCustomOpPropListArguments] =
2671       reinterpret_cast<int(*)(void)>(
2672         static_cast<int(*)(char***, void*)>(opPropListArgument));
2673     ret->callbacks[kCustomOpPropListOutputs] =
2674       reinterpret_cast<int(*)(void)>(
2675         static_cast<int(*)(char***, void*)>(opPropListOutputs));
2676     ret->callbacks[kCustomOpPropListAuxiliaryStates] =
2677       reinterpret_cast<int(*)(void)>(
2678         static_cast<int(*)(char***, void*)>(opPropListAuxStates));
2679     ret->callbacks[kCustomOpPropInferShape] =
2680       reinterpret_cast<int(*)(void)>(
2681         static_cast<int (*)(int, int*, unsigned**, void*)>(opPropInferShape));
2682     ret->callbacks[kCustomOpPropDeclareBackwardDependency] =
2683       reinterpret_cast<int(*)(void)>(
2684         static_cast<int(*)(const int*, const int*, const int*, int* num_deps, int**, void*)>(
2685           opPropDeclareBkDep));
2686     ret->callbacks[kCustomOpPropCreateOperator] =
2687       reinterpret_cast<int(*)(void)>(
2688         static_cast<int(*)(const char*, int, unsigned**, int*, int*, MXCallbackList*, void*)>(
2689           opPropCreateOp));
2690     ret->callbacks[kCustomOpPropInferType] =
2691       reinterpret_cast<int(*)(void)>(
2692         static_cast<int(*)(int, int*, void*)>(opPropInferType));
2693 
2694     ret->contexts = new void*[ret->num_callbacks];
2695     ret->contexts[kCustomOpPropDelete] =
2696       reinterpret_cast<void *>(const_cast<char *>(opType));
2697     ret->contexts[kCustomOpPropListArguments] =
2698       reinterpret_cast<void *>(const_cast<char *>(opType));
2699     ret->contexts[kCustomOpPropListOutputs] =
2700       reinterpret_cast<void *>(const_cast<char *>(opType));
2701     ret->contexts[kCustomOpPropListAuxiliaryStates] =
2702       reinterpret_cast<void *>(const_cast<char *>(opType));
2703     ret->contexts[kCustomOpPropInferShape] =
2704       reinterpret_cast<void *>(const_cast<char *>(opType));
2705     ret->contexts[kCustomOpPropDeclareBackwardDependency] =
2706       reinterpret_cast<void *>(const_cast<char *>(opType));
2707     ret->contexts[kCustomOpPropCreateOperator] =
2708       reinterpret_cast<void *>(const_cast<char *>(opType));
2709     ret->contexts[kCustomOpPropInferType] =
2710       reinterpret_cast<void *>(const_cast<char *>(opType));
2711     return success;
2712   };
2713 
2714   CustomOpPropCreator creator =
2715     static_cast<int(*)(const char*, const int, const char**, const char**, MXCallbackList*)>(
2716       creatorLambda);
2717   return MXCustomOpRegister(regName, creator);
2718 }
2719 
2720 struct JNIString {
2721   JNIEnv *env_;
2722   jstring java_string_;
2723   const char *str_;
JNIStringJNIString2724   inline JNIString(JNIEnv *env, const jstring& java_string)
2725   : env_(env)
2726     , java_string_(java_string) {
2727     str_ = env_->GetStringUTFChars(java_string_, 0);
2728   }
~JNIStringJNIString2729   inline ~JNIString() {
2730     if (str_) {
2731       env_->ReleaseStringUTFChars(java_string_, str_);
2732     }
2733   }
operator ()JNIString2734   inline const char *operator ()() const {
2735     return str_;
2736   }
2737 };
2738 
2739 struct JNIStringArray {
2740   std::vector<std::unique_ptr<JNIString>> jni_strings_;
2741   std::vector<const char *> strings_;
JNIStringArrayJNIStringArray2742   JNIStringArray(JNIEnv *env, const jobjectArray& stringArray) {
2743     const int count = env->GetArrayLength(stringArray);
2744     jni_strings_.reserve(count);
2745     strings_.reserve(count);
2746     for (int i = 0; i < count; ++i) {
2747       jstring string = static_cast<jstring>(env->GetObjectArrayElement(stringArray, i));
2748       jni_strings_.emplace_back(std::unique_ptr<JNIString>(new JNIString(env, string)));
2749       strings_.emplace_back((*jni_strings_.rbegin())->str_);
2750     }
2751   }
operator ()JNIStringArray2752   const char * const* operator ()() const { return &strings_[0]; }
2753 };
2754 
Java_org_apache_mxnet_LibInfo_mxSetProfilerConfig(JNIEnv * env,jobject obj,jobjectArray keys,jobjectArray vals)2755 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetProfilerConfig
2756   (JNIEnv *env, jobject obj, jobjectArray keys, jobjectArray vals) {
2757   const int stringCount = env->GetArrayLength(keys);
2758   CHECK_EQ(stringCount, env->GetArrayLength(vals)) << "Key and value arrays must be the same size";
2759 
2760   JNIStringArray the_keys(env, keys), the_vals(env, vals);
2761 
2762   const int ret = MXSetProfilerConfig(stringCount, the_keys(), the_vals());
2763   return ret;
2764 }
2765 
Java_org_apache_mxnet_LibInfo_mxSetProfilerState(JNIEnv * env,jobject obj,jint jstate)2766 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetProfilerState
2767   (JNIEnv *env, jobject obj, jint jstate) {
2768   return MXSetProfilerState(jstate);
2769 }
2770 
Java_org_apache_mxnet_LibInfo_mxDumpProfile(JNIEnv * env,jobject obj,jint finished)2771 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
2772   (JNIEnv *env, jobject obj, jint finished) {
2773   return MXDumpProfile(finished);
2774 }
2775 
2776 // Numpy
Java_org_apache_mxnet_LibInfo_mxIsNumpyShape(JNIEnv * env,jobject obj,jobject compatibleRef)2777 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
2778   (JNIEnv *env, jobject obj, jobject compatibleRef) {
2779   int isNumpyShape;
2780   int ret = MXIsNumpyShape(&isNumpyShape);
2781   SetIntField(env, compatibleRef, isNumpyShape);
2782   return ret;
2783 }
2784 
Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape(JNIEnv * env,jobject obj,jint isNpComp,jobject prevRef)2785 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape
2786   (JNIEnv *env, jobject obj, jint isNpComp, jobject prevRef) {
2787   int prev;
2788   int ret = MXSetIsNumpyShape(isNpComp, &prev);
2789   SetIntField(env, prevRef, prev);
2790   return ret;
2791 }
2792