1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file org_apache_mxnet_native_c_api.cc
22 * \brief JNI function implementations
23 */
24 #include "org_apache_mxnet_native_c_api.h" // generated by javah
25 #include <nnvm/c_api.h>
26 #include <mxnet/c_api.h>
27 #include <dmlc/logging.h>
28 #include <mxnet/ndarray.h>
29 #include <../src/common/cuda_utils.h>
30 #include <mutex>
31 #include <iostream>
32 #include <functional>
33 #include <string>
34 #include <unordered_map>
35 #include <vector>
36 #include "jni_helper_func.h"
37
38 JavaVM *_jvm;
39
Java_org_apache_mxnet_LibInfo_nativeLibInit(JNIEnv * env,jobject obj)40 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_nativeLibInit
41 (JNIEnv *env, jobject obj) {
42 return env->GetJavaVM(&_jvm);
43 }
44
Java_org_apache_mxnet_LibInfo_mxListAllOpNames(JNIEnv * env,jobject obj,jobject nameList)45 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListAllOpNames
46 (JNIEnv *env, jobject obj, jobject nameList) {
47 mx_uint outSize;
48 const char **outArray;
49 int ret = MXListAllOpNames(&outSize, &outArray);
50
51 jclass listCls = env->FindClass("scala/collection/mutable/ListBuffer");
52 jmethodID listAppend = env->GetMethodID(listCls,
53 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
54 for (size_t i = 0; i < outSize; ++i) {
55 env->CallObjectMethod(nameList, listAppend, env->NewStringUTF(outArray[i]));
56 }
57
58 return ret;
59 }
60
Java_org_apache_mxnet_LibInfo_nnGetOpHandle(JNIEnv * env,jobject obj,jstring jopname,jobject jhandle)61 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_nnGetOpHandle
62 (JNIEnv *env, jobject obj, jstring jopname, jobject jhandle) {
63 OpHandle handle;
64 const char *opname = env->GetStringUTFChars(jopname, 0);
65 int ret = NNGetOpHandle(opname, &handle);
66 env->ReleaseStringUTFChars(jopname, opname);
67
68 jclass refClass = env->FindClass("org/apache/mxnet/Base$RefLong");
69 jfieldID refFid = env->GetFieldID(refClass, "value", "J");
70 env->SetLongField(jhandle, refFid, reinterpret_cast<jlong>(handle));
71
72 return ret;
73 }
74
Java_org_apache_mxnet_LibInfo_mxNDArrayCreateNone(JNIEnv * env,jobject obj,jobject ndArrayHandle)75 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateNone
76 (JNIEnv *env, jobject obj, jobject ndArrayHandle) {
77 NDArrayHandle out;
78 int ret = MXNDArrayCreateNone(&out);
79 SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
80 return ret;
81 }
82
Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx(JNIEnv * env,jobject obj,jintArray shape,jint ndim,jint devType,jint devId,jint delayAlloc,jint dtype,jobject ndArrayHandle)83 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx
84 (JNIEnv *env, jobject obj, jintArray shape, jint ndim, jint devType,
85 jint devId, jint delayAlloc, jint dtype, jobject ndArrayHandle) {
86 jint *shapeArr = env->GetIntArrayElements(shape, NULL);
87 NDArrayHandle out;
88 int ret = MXNDArrayCreateEx(reinterpret_cast<mx_uint *>(shapeArr), static_cast<mx_uint>(ndim),
89 devType, devId, delayAlloc, dtype, &out);
90 env->ReleaseIntArrayElements(shape, shapeArr, 0);
91 SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
92 return ret;
93 }
94
Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx(JNIEnv * env,jobject obj,jint storageType,jintArray shape,jint ndim,jint devType,jint devId,jint delayAlloc,jint dtype,jint numAux,jintArray auxTypes,jintArray auxNdims,jintArray auxShapes,jobject ndArrayHandle)95 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx
96 (JNIEnv *env, jobject obj, jint storageType, jintArray shape, jint ndim, jint devType,
97 jint devId, jint delayAlloc, jint dtype, jint numAux, jintArray auxTypes,
98 jintArray auxNdims, jintArray auxShapes, jobject ndArrayHandle) {
99 jint *shapeArr = env->GetIntArrayElements(shape, NULL);
100 jint *auxTypesArr = env->GetIntArrayElements(auxTypes, NULL);
101 jint *auxNdimsArr = env->GetIntArrayElements(auxNdims, NULL);
102 jint *auxShapesArr = env->GetIntArrayElements(auxShapes, NULL);
103 NDArrayHandle out;
104 int ret = MXNDArrayCreateSparseEx(storageType,
105 reinterpret_cast<const mx_uint *>(shapeArr),
106 static_cast<mx_uint>(ndim),
107 devType, devId, delayAlloc, dtype,
108 static_cast<mx_uint>(numAux),
109 reinterpret_cast<int *>(auxTypesArr),
110 reinterpret_cast<mx_uint *>(auxNdimsArr),
111 reinterpret_cast<const mx_uint *>(auxShapesArr), &out);
112 env->ReleaseIntArrayElements(shape, shapeArr, 0);
113 env->ReleaseIntArrayElements(auxTypes, auxTypesArr, 0);
114 env->ReleaseIntArrayElements(auxNdims, auxNdimsArr, 0);
115 env->ReleaseIntArrayElements(auxShapes, auxShapesArr, 0);
116 SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
117 return ret;
118 }
119
Java_org_apache_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv * env,jobject obj)120 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv *env, jobject obj) {
121 return MXNDArrayWaitAll();
122 }
123
Java_org_apache_mxnet_LibInfo_mxNDArrayWaitToRead(JNIEnv * env,jobject obj,jlong arrayPtr)124 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitToRead
125 (JNIEnv *env, jobject obj, jlong arrayPtr) {
126 return MXNDArrayWaitToRead(reinterpret_cast<NDArrayHandle>(arrayPtr));
127 }
128
Java_org_apache_mxnet_LibInfo_mxListFunctions(JNIEnv * env,jobject obj,jobject functions)129 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListFunctions
130 (JNIEnv *env, jobject obj, jobject functions) {
131 jclass longCls = env->FindClass("java/lang/Long");
132 jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
133
134 // scala.collection.mutable.ListBuffer append method
135 jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
136 jmethodID listAppend = env->GetMethodID(listClass,
137 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
138
139 // Get function list
140 FunctionHandle *outArray;
141 mx_uint outSize;
142 int ret = MXListFunctions(&outSize, &outArray);
143 for (size_t i = 0; i < outSize; ++i) {
144 env->CallObjectMethod(functions, listAppend,
145 env->NewObject(longCls, longConst, outArray[i]));
146 }
147 return ret;
148 }
149
Java_org_apache_mxnet_LibInfo_mxFuncDescribe(JNIEnv * env,jobject obj,jlong funcPtr,jobject nUsedVars,jobject nScalars,jobject nMutateVars,jobject typeMask)150 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncDescribe
151 (JNIEnv *env, jobject obj, jlong funcPtr, jobject nUsedVars,
152 jobject nScalars, jobject nMutateVars, jobject typeMask) {
153 mx_uint numUseVars;
154 mx_uint numScalars;
155 mx_uint numMutateVars;
156 int type;
157 int ret = MXFuncDescribe(reinterpret_cast<FunctionHandle>(funcPtr), &numUseVars,
158 &numScalars, &numMutateVars, &type);
159
160 jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
161 jfieldID value = env->GetFieldID(refIntClass, "value", "I");
162 env->SetIntField(nUsedVars, value, static_cast<jint>(numUseVars));
163 env->SetIntField(nScalars, value, static_cast<jint>(numScalars));
164 env->SetIntField(nMutateVars, value, static_cast<jint>(numMutateVars));
165 env->SetIntField(typeMask, value, static_cast<jint>(type));
166
167 return ret;
168 }
169
Java_org_apache_mxnet_LibInfo_mxFuncGetInfo(JNIEnv * env,jobject obj,jlong funcPtr,jobject name,jobject desc,jobject numArgs,jobject argNames,jobject argTypes,jobject argDescs)170 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncGetInfo
171 (JNIEnv *env, jobject obj, jlong funcPtr, jobject name, jobject desc,
172 jobject numArgs, jobject argNames, jobject argTypes, jobject argDescs) {
173 const char *cName;
174 const char *cDesc;
175 mx_uint cNumArgs;
176 const char **cArgNames;
177 const char **cArgTypes;
178 const char **cArgDescs;
179 int ret = MXFuncGetInfo(reinterpret_cast<FunctionHandle>(funcPtr),
180 &cName, &cDesc, &cNumArgs,
181 &cArgNames, &cArgTypes, &cArgDescs);
182
183 jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
184 jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
185
186 jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
187 jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
188
189 // scala.collection.mutable.ListBuffer append method
190 jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
191 jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
192 "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
193
194 env->SetObjectField(name, valueStr, env->NewStringUTF(cName));
195 env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc));
196 env->SetIntField(numArgs, valueInt, static_cast<jint>(cNumArgs));
197 for (size_t i = 0; i < cNumArgs; ++i) {
198 env->CallObjectMethod(argNames, listAppend, env->NewStringUTF(cArgNames[i]));
199 env->CallObjectMethod(argTypes, listAppend, env->NewStringUTF(cArgTypes[i]));
200 env->CallObjectMethod(argDescs, listAppend, env->NewStringUTF(cArgDescs[i]));
201 }
202
203 return ret;
204 }
205
Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx(JNIEnv * env,jobject obj,jlong funcPtr,jlongArray inputs,jlongArray outputsGiven,jobject outputs,jint numParams,jobjectArray paramKeys,jobjectArray paramVals,jobject outStypes)206 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx
207 (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray inputs,
208 jlongArray outputsGiven, jobject outputs, jint numParams,
209 jobjectArray paramKeys, jobjectArray paramVals, jobject outStypes) {
210
211 const char **cParamKeys = NULL;
212 const char **cParamVals = NULL;
213 if (numParams > 0) {
214 cParamKeys = new const char *[numParams];
215 cParamVals = new const char *[numParams];
216 for (int i = 0; i < numParams; i++) {
217 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
218 const char *key = env->GetStringUTFChars(jkey, 0);
219 cParamKeys[i] = key;
220 env->DeleteLocalRef(jkey);
221 jstring jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
222 const char *val = env->GetStringUTFChars(jval, 0);
223 cParamVals[i] = val;
224 env->DeleteLocalRef(jval);
225 }
226 }
227
228 int numOutputs = 0;
229 jlong *cOutputsGiven = NULL;
230 NDArrayHandle *cOutputs = NULL;
231 const int *cOutStypes;
232 if (outputsGiven) {
233 cOutputsGiven = env->GetLongArrayElements(outputsGiven, NULL);
234 cOutputs = reinterpret_cast<NDArrayHandle *>(cOutputsGiven);
235 numOutputs = static_cast<int>(env->GetArrayLength(outputsGiven));
236 }
237 jlong *cInputs = env->GetLongArrayElements(inputs, NULL);
238 jsize numInputs = env->GetArrayLength(inputs);
239 int ret = MXImperativeInvokeEx(reinterpret_cast<AtomicSymbolCreator>(funcPtr),
240 static_cast<int>(numInputs),
241 reinterpret_cast<NDArrayHandle *>(cInputs),
242 &numOutputs,
243 &cOutputs,
244 static_cast<int>(numParams),
245 cParamKeys,
246 cParamVals,
247 &cOutStypes);
248 env->ReleaseLongArrayElements(inputs, cInputs, 0);
249
250 // release allocated memory
251 if (numParams > 0) {
252 for (int i = 0; i < numParams; i++) {
253 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
254 env->ReleaseStringUTFChars(jkey, cParamKeys[i]);
255 env->DeleteLocalRef(jkey);
256 jstring jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
257 env->ReleaseStringUTFChars(jval, cParamVals[i]);
258 env->DeleteLocalRef(jval);
259 }
260 delete[] cParamKeys;
261 delete[] cParamVals;
262 }
263
264 if (cOutputs) {
265 jclass longCls = env->FindClass("java/lang/Long");
266 jclass intCls = env->FindClass("java/lang/Integer");
267 jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
268 jmethodID intConst = env->GetMethodID(intCls, "<init>", "(I)V");
269 // scala.collection.mutable.ListBuffer append method
270 jclass listClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
271 jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
272 "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
273 for (int i = 0; i < numOutputs; ++i) {
274 env->CallObjectMethod(outputs, listAppend,
275 env->NewObject(longCls, longConst,
276 reinterpret_cast<uint64_t>(cOutputs[i])));
277 env->CallObjectMethod(outStypes, listAppend,
278 env->NewObject(intCls, intConst,
279 cOutStypes[i]));
280 }
281 }
282
283 if (cOutputsGiven) {
284 env->ReleaseLongArrayElements(outputsGiven, cOutputsGiven, 0);
285 }
286
287 return ret;
288 }
289
Java_org_apache_mxnet_LibInfo_mxFuncInvoke(JNIEnv * env,jobject obj,jlong funcPtr,jlongArray useVars,jfloatArray scalarArgs,jlongArray mutateVars)290 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncInvoke
291 (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray useVars,
292 jfloatArray scalarArgs, jlongArray mutateVars) {
293 jlong *cUseVars = env->GetLongArrayElements(useVars, NULL);
294 jfloat *cScalarArgs = env->GetFloatArrayElements(scalarArgs, NULL);
295 jlong *cMutateVars = env->GetLongArrayElements(mutateVars, NULL);
296 int ret = MXFuncInvoke(reinterpret_cast<FunctionHandle>(funcPtr),
297 reinterpret_cast<NDArrayHandle *>(cUseVars),
298 reinterpret_cast<mx_float *>(cScalarArgs),
299 reinterpret_cast<NDArrayHandle *>(cMutateVars));
300 env->ReleaseLongArrayElements(useVars, cUseVars, 0);
301 env->ReleaseFloatArrayElements(scalarArgs, cScalarArgs, 0);
302 env->ReleaseLongArrayElements(mutateVars, cMutateVars, 0);
303 return ret;
304 }
305
Java_org_apache_mxnet_LibInfo_mxFuncInvokeEx(JNIEnv * env,jobject obj,jlong funcPtr,jlongArray useVars,jfloatArray scalarArgs,jlongArray mutateVars,jint numParams,jobjectArray paramKeys,jobjectArray paramVals)306 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncInvokeEx
307 (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray useVars,
308 jfloatArray scalarArgs, jlongArray mutateVars,
309 jint numParams, jobjectArray paramKeys, jobjectArray paramVals) {
310 jlong *cUseVars = env->GetLongArrayElements(useVars, NULL);
311 jfloat *cScalarArgs = env->GetFloatArrayElements(scalarArgs, NULL);
312 jlong *cMutateVars = env->GetLongArrayElements(mutateVars, NULL);
313 jbyte **cParamKeys = NULL;
314 jbyte **cParamVals = NULL;
315 if (numParams > 0) {
316 cParamKeys = new jbyte *[numParams];
317 cParamVals = new jbyte *[numParams];
318 for (int i = 0; i < numParams; i++) {
319 jbyteArray jkey = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramKeys, i));
320 jbyte *cParamKey = env->GetByteArrayElements(jkey, NULL);
321 cParamKeys[i] = cParamKey;
322 env->DeleteLocalRef(jkey);
323 jbyteArray jval = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramVals, i));
324 jbyte *cParamVal = env->GetByteArrayElements(jval, NULL);
325 cParamVals[i] = cParamVal;
326 env->DeleteLocalRef(jval);
327 }
328 }
329 int ret = MXFuncInvokeEx(reinterpret_cast<FunctionHandle>(funcPtr),
330 reinterpret_cast<NDArrayHandle *>(cUseVars),
331 reinterpret_cast<mx_float *>(cScalarArgs),
332 reinterpret_cast<NDArrayHandle *>(cMutateVars),
333 static_cast<int>(numParams),
334 reinterpret_cast<char **>(cParamKeys),
335 reinterpret_cast<char **>(cParamVals));
336 env->ReleaseLongArrayElements(useVars, cUseVars, 0);
337 env->ReleaseFloatArrayElements(scalarArgs, cScalarArgs, 0);
338 env->ReleaseLongArrayElements(mutateVars, cMutateVars, 0);
339 if (numParams > 0) {
340 for (int i = 0; i < numParams; i++) {
341 jbyteArray jkey = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramKeys, i));
342 env->ReleaseByteArrayElements(jkey, cParamKeys[i], 0);
343 env->DeleteLocalRef(jkey);
344 jbyteArray jval = reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(paramVals, i));
345 env->ReleaseByteArrayElements(jval, cParamVals[i], 0);
346 env->DeleteLocalRef(jval);
347 }
348 delete[] cParamKeys;
349 delete[] cParamVals;
350 }
351 return ret;
352 }
353
Java_org_apache_mxnet_LibInfo_mxNDArraySaveRawBytes(JNIEnv * env,jobject obj,jlong ndArrayPtr,jobject dataBuf)354 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySaveRawBytes
355 (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject dataBuf) {
356 size_t length;
357 const char *pdata;
358 int ret = MXNDArraySaveRawBytes(reinterpret_cast<NDArrayHandle>(ndArrayPtr), &length, &pdata);
359
360 // fill dataBuf
361 jclass byteClass = env->FindClass("java/lang/Byte");
362 jmethodID newByte = env->GetMethodID(byteClass, "<init>", "(B)V");
363 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
364 jmethodID arrayAppend = env->GetMethodID(arrayClass,
365 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
366 for (size_t i = 0; i < length; ++i) {
367 jobject data = env->NewObject(byteClass, newByte, static_cast<jbyte>(pdata[i]));
368 env->CallObjectMethod(dataBuf, arrayAppend, data);
369 env->DeleteLocalRef(data);
370 }
371
372 return ret;
373 }
374
Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes(JNIEnv * env,jobject obj,jbyteArray bytes,jobject handleRef)375 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes
376 (JNIEnv *env, jobject obj, jbyteArray bytes, jobject handleRef) {
377 int size = env->GetArrayLength(bytes);
378 jbyte *byteArr = env->GetByteArrayElements(bytes, NULL);
379 NDArrayHandle out;
380 int ret = MXNDArrayLoadFromRawBytes(reinterpret_cast<const void *>(byteArr),
381 static_cast<size_t>(size), &out);
382 env->ReleaseByteArrayElements(bytes, byteArr, 0);
383 SetLongField(env, handleRef, reinterpret_cast<jlong>(out));
384 return ret;
385 }
386
Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape(JNIEnv * env,jobject obj,jlong ndArrayPtr,jobject ndimRef,jobject dataBuf)387 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
388 (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject ndimRef, jobject dataBuf) {
389 int ndim;
390 const int *pdata;
391 int ret = MXNDArrayGetShapeEx(reinterpret_cast<NDArrayHandle>(ndArrayPtr), &ndim, &pdata);
392
393 // fill dataBuf
394 jclass integerClass = env->FindClass("java/lang/Integer");
395 jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");
396
397 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
398 jmethodID arrayAppend = env->GetMethodID(arrayClass,
399 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
400 for (int i = 0; i < ndim; ++i) {
401 jobject data = env->NewObject(integerClass, newInteger, pdata[i]);
402 env->CallObjectMethod(dataBuf, arrayAppend, data);
403 env->DeleteLocalRef(data);
404 }
405
406 // set ndimRef
407 jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
408 jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
409 env->SetIntField(ndimRef, valueInt, ndim);
410
411 return ret;
412 }
413
Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray(JNIEnv * env,jobject obj,jlong dstPtr,jlong srcPtr,jint locator)414 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray
415 (JNIEnv *env, jobject obj, jlong dstPtr, jlong srcPtr, jint locator) {
416 int ret = MXNDArraySyncCopyFromNDArray(reinterpret_cast<NDArrayHandle>(dstPtr),
417 reinterpret_cast<NDArrayHandle>(srcPtr),
418 locator);
419 return ret;
420 }
421
Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyToCPU(JNIEnv * env,jobject obj,jlong ndArrayPtr,jbyteArray data,jint size)422 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyToCPU
423 (JNIEnv *env, jobject obj, jlong ndArrayPtr, jbyteArray data, jint size) {
424 jbyte *pdata = env->GetByteArrayElements(data, NULL);
425 int ret = MXNDArraySyncCopyToCPU(reinterpret_cast<NDArrayHandle>(ndArrayPtr),
426 reinterpret_cast<void *>(pdata), size);
427 env->ReleaseByteArrayElements(data, pdata, 0); // copy back to java array automatically
428 return ret;
429 }
430
Java_org_apache_mxnet_LibInfo_mxNDArraySlice(JNIEnv * env,jobject obj,jlong ndArrayPtr,jint start,jint end,jobject slicedHandle)431 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySlice
432 (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint start, jint end, jobject slicedHandle) {
433 NDArrayHandle out;
434 int ret = MXNDArraySlice(reinterpret_cast<NDArrayHandle>(ndArrayPtr), start, end, &out);
435 SetLongField(env, slicedHandle, reinterpret_cast<jlong>(out));
436 return ret;
437 }
438
Java_org_apache_mxnet_LibInfo_mxNDArrayAt(JNIEnv * env,jobject obj,jlong ndArrayPtr,jint idx,jobject jout)439 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
440 (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint idx, jobject jout) {
441 NDArrayHandle out;
442 int ret = MXNDArrayAt(reinterpret_cast<NDArrayHandle>(ndArrayPtr), idx, &out);
443 SetLongField(env, jout, reinterpret_cast<jlong>(out));
444 return ret;
445 }
446
Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64(JNIEnv * env,jobject obj,jlong ndArrayPtr,jint ndim,jlongArray dims,jboolean reverse,jobject reshapedHandle)447 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
448 (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim,
449 jlongArray dims, jboolean reverse, jobject reshapedHandle) {
450 NDArrayHandle out;
451 jlong *pdims = env->GetLongArrayElements(dims, NULL);
452 int ret = MXNDArrayReshape64(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
453 reinterpret_cast<dim_t *>(pdims), reverse, &out);
454 SetLongField(env, reshapedHandle, reinterpret_cast<jlong>(out));
455 env->ReleaseLongArrayElements(dims, pdims, 0);
456 return ret;
457 }
458
Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU(JNIEnv * env,jobject obj,jlong arrayPtr,jfloatArray sourceArr,jint arrSize)459 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
460 (JNIEnv *env, jobject obj, jlong arrayPtr, jfloatArray sourceArr, jint arrSize) {
461 jfloat *sourcePtr = env->GetFloatArrayElements(sourceArr, NULL);
462 int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast<NDArrayHandle>(arrayPtr),
463 static_cast<const mx_float *>(sourcePtr), arrSize);
464 env->ReleaseFloatArrayElements(sourceArr, sourcePtr, 0);
465 return ret;
466 }
467
Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU(JNIEnv * env,jobject obj,jlong arrayPtr,jdoubleArray sourceArr,jint arrSize)468 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
469 (JNIEnv *env, jobject obj, jlong arrayPtr, jdoubleArray sourceArr, jint arrSize) {
470 jdouble *sourcePtr = env->GetDoubleArrayElements(sourceArr, NULL);
471 int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast<NDArrayHandle>(arrayPtr),
472 static_cast<const double *>(sourcePtr), arrSize);
473 env->ReleaseDoubleArrayElements(sourceArr, sourcePtr, 0);
474 return ret;
475 }
476
Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray(JNIEnv * env,jobject obj,jlong arrayPtr,jobject ndArrayHandle)477 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray
478 (JNIEnv *env, jobject obj, jlong arrayPtr, jobject ndArrayHandle) {
479 NDArrayHandle out;
480 int ret = MXNDArrayGetDataNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
481 &out);
482 SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
483 return ret;
484 }
485
Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray(JNIEnv * env,jobject obj,jlong arrayPtr,jint location,jobject ndArrayHandle)486 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray
487 (JNIEnv *env, jobject obj, jlong arrayPtr, jint location, jobject ndArrayHandle) {
488 NDArrayHandle out;
489 int ret = MXNDArrayGetAuxNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
490 static_cast<mx_uint>(location),
491 &out);
492 SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
493 return ret;
494 }
495
Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext(JNIEnv * env,jobject obj,jlong arrayPtr,jobject devTypeId,jobject devId)496 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext
497 (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) {
498 int outDevType;
499 int outDevId;
500 int ret = MXNDArrayGetContext(reinterpret_cast<NDArrayHandle>(arrayPtr), &outDevType, &outDevId);
501 jclass refClass = env->FindClass("org/apache/mxnet/Base$RefInt");
502 jfieldID refFid = env->GetFieldID(refClass, "value", "I");
503 env->SetIntField(devTypeId, refFid, outDevType);
504 env->SetIntField(devId, refFid, outDevId);
505 return ret;
506 }
507
Java_org_apache_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env,jobject obj,jlong ndArrayHandle)508 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayFree
509 (JNIEnv * env, jobject obj, jlong ndArrayHandle) {
510 return MXNDArrayFree(reinterpret_cast<NDArrayHandle>(ndArrayHandle));
511 }
512
Java_org_apache_mxnet_LibInfo_mxNDArrayLoad(JNIEnv * env,jobject obj,jstring jfname,jobject joutSize,jobject jhandles,jobject joutNameSize,jobject jnames)513 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoad
514 (JNIEnv * env, jobject obj, jstring jfname, jobject joutSize,
515 jobject jhandles, jobject joutNameSize, jobject jnames) {
516 mx_uint outSize;
517 NDArrayHandle *outArr;
518 mx_uint outNameSize;
519 const char **outNames;
520
521 const char *fname = env->GetStringUTFChars(jfname, 0);
522 int ret = MXNDArrayLoad(fname, &outSize, &outArr, &outNameSize, &outNames);
523 env->ReleaseStringUTFChars(jfname, fname);
524
525 if (ret) {
526 return ret;
527 }
528
529 // fill sizes
530 jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
531 jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
532 env->SetIntField(joutSize, valueInt, outSize);
533 env->SetIntField(joutNameSize, valueInt, outNameSize);
534
535 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
536 jmethodID arrayAppend = env->GetMethodID(arrayClass,
537 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
538
539 // fill handles
540 jclass longCls = env->FindClass("java/lang/Long");
541 jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
542 for (size_t i = 0; i < outSize; ++i) {
543 jobject handle = env->NewObject(longCls, longConst, outArr[i]);
544 env->CallObjectMethod(jhandles, arrayAppend, handle);
545 env->DeleteLocalRef(handle);
546 }
547
548 // fill names
549 for (size_t i = 0; i < outNameSize; ++i) {
550 jstring jname = env->NewStringUTF(outNames[i]);
551 env->CallObjectMethod(jnames, arrayAppend, jname);
552 env->DeleteLocalRef(jname);
553 }
554
555 return ret;
556 }
557
Java_org_apache_mxnet_LibInfo_mxNDArraySave(JNIEnv * env,jobject obj,jstring jfname,jlongArray jhandles,jobjectArray jkeys)558 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySave
559 (JNIEnv * env, jobject obj, jstring jfname, jlongArray jhandles, jobjectArray jkeys) {
560 int numArgs = env->GetArrayLength(jhandles);
561 const char **keys = NULL;
562 if (jkeys != NULL) {
563 keys = new const char *[numArgs];
564 for (int i = 0; i < numArgs; i++) {
565 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
566 const char *key = env->GetStringUTFChars(jkey, 0);
567 keys[i] = key;
568 env->DeleteLocalRef(jkey);
569 }
570 }
571
572 const char *fname = env->GetStringUTFChars(jfname, 0);
573 jlong *handles = env->GetLongArrayElements(jhandles, NULL);
574
575 int ret = MXNDArraySave(fname, static_cast<mx_uint>(numArgs),
576 reinterpret_cast<NDArrayHandle *>(handles), keys);
577
578 env->ReleaseLongArrayElements(jhandles, handles, 0);
579 env->ReleaseStringUTFChars(jfname, fname);
580
581 // release allocated memory
582 if (jkeys != NULL) {
583 for (int i = 0; i < numArgs; i++) {
584 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
585 env->ReleaseStringUTFChars(jkey, keys[i]);
586 env->DeleteLocalRef(jkey);
587 }
588 delete[] keys;
589 }
590
591 return ret;
592 }
593
Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType(JNIEnv * env,jobject obj,jlong jhandle,jobject jdtype)594 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType
595 (JNIEnv * env, jobject obj, jlong jhandle, jobject jdtype) {
596 int dtype;
597 int ret = MXNDArrayGetDType(reinterpret_cast<NDArrayHandle>(jhandle), &dtype);
598 SetIntField(env, jdtype, dtype);
599 return ret;
600 }
601
Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType(JNIEnv * env,jobject obj,jlong jhandle,jobject jstype)602 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType
603 (JNIEnv * env, jobject obj, jlong jhandle, jobject jstype) {
604 int stype;
605 int ret = MXNDArrayGetStorageType(reinterpret_cast<NDArrayHandle>(jhandle), &stype);
606 SetIntField(env, jstype, stype);
607 return ret;
608 }
609
Java_org_apache_mxnet_LibInfo_mxInitPSEnv(JNIEnv * env,jobject obj,jobjectArray jkeys,jobjectArray jvals)610 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxInitPSEnv
611 (JNIEnv *env, jobject obj, jobjectArray jkeys, jobjectArray jvals) {
612 // keys and values
613 int paramSize = env->GetArrayLength(jkeys);
614 const char** keys = new const char*[paramSize];
615 const char** vals = new const char*[paramSize];
616 jstring jkey, jval;
617 // use strcpy and release char* created by JNI inplace
618 for (int i = 0; i < paramSize; i++) {
619 jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
620 const char* ckey = env->GetStringUTFChars(jkey, 0);
621 keys[i] = ckey;
622 env->DeleteLocalRef(jkey);
623
624 jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, i));
625 const char* cval = env->GetStringUTFChars(jval, 0);
626 vals[i] = cval;
627 env->DeleteLocalRef(jval);
628 }
629
630 int ret = MXInitPSEnv(static_cast<mx_uint>(paramSize),
631 static_cast<const char**>(keys),
632 static_cast<const char**>(vals));
633
634 // release keys and vals
635 for (int i = 0; i < paramSize; i++) {
636 jstring key = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
637 env->ReleaseStringUTFChars(key, keys[i]);
638 env->DeleteLocalRef(key);
639
640 jstring value = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, i));
641 env->ReleaseStringUTFChars(value, vals[i]);
642 env->DeleteLocalRef(value);
643 }
644 delete[] keys;
645 delete[] vals;
646
647 return ret;
648 }
649
KVStoreServerControllerFunc(int head,const char * body,void * handle)650 extern "C" void KVStoreServerControllerFunc
651 (int head, const char *body, void *handle) {
652 jobject controllerObjGlb = static_cast<jobject>(handle);
653
654 JNIEnv *env;
655 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
656
657 // find java controller method
658 jclass ctrlClass = env->GetObjectClass(controllerObjGlb);
659 jmethodID ctrlFunc = env->GetMethodID(ctrlClass, "invoke", "(ILjava/lang/String;)V");
660
661 jstring jbody = env->NewStringUTF(body);
662 env->CallVoidMethod(controllerObjGlb, ctrlFunc, head, jbody);
663 env->DeleteLocalRef(jbody);
664
665 env->DeleteLocalRef(ctrlClass);
666 // FIXME(Yizhi): This function can be called multiple times,
667 // can we find a way to safely destroy this object ?
668 // env->DeleteGlobalRef(controllerObjGlb);
669 }
670
Java_org_apache_mxnet_LibInfo_mxKVStoreRunServer(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject controllerObj)671 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreRunServer
672 (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject controllerObj) {
673 jobject controllerObjGlb = env->NewGlobalRef(controllerObj);
674 return MXKVStoreRunServer(reinterpret_cast<KVStoreHandle>(kvStorePtr),
675 KVStoreServerControllerFunc,
676 reinterpret_cast<void *>(controllerObjGlb));
677 }
678
KVStoreUpdaterCallbackFunc(int key,NDArrayHandle recv,NDArrayHandle local,void * handle)679 extern "C" void KVStoreUpdaterCallbackFunc
680 (int key, NDArrayHandle recv, NDArrayHandle local, void *handle) {
681 jobject updaterFuncObjGlb = static_cast<jobject>(handle);
682
683 JNIEnv *env;
684 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
685
686 // find java updater method
687 jclass updtClass = env->GetObjectClass(updaterFuncObjGlb);
688 jmethodID updtFunc = env->GetMethodID(updtClass,
689 "update", "(ILorg/apache/mxnet/NDArray;Lorg/apache/mxnet/NDArray;)V");
690
691 // find java NDArray constructor
692 jclass ndObjClass = env->FindClass("org/apache/mxnet/NDArray");
693 jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "<init>", "(JZZ)V");
694
695 jobject ndRecv = env->NewObject(ndObjClass, ndObjConstructor,
696 reinterpret_cast<jlong>(recv), true);
697 jobject ndLocal = env->NewObject(ndObjClass, ndObjConstructor,
698 reinterpret_cast<jlong>(local), true);
699
700 env->CallVoidMethod(updaterFuncObjGlb, updtFunc, key, ndRecv, ndLocal);
701
702 env->DeleteLocalRef(ndLocal);
703 env->DeleteLocalRef(ndRecv);
704 env->DeleteLocalRef(ndObjClass);
705 env->DeleteLocalRef(updtClass);
706 // FIXME(Yizhi): This function can be called multiple times,
707 // can we find a way to safely destroy this object ?
708 // env->DeleteGlobalRef(updaterFuncObjGlb);
709 }
710
Java_org_apache_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject updaterFuncObj)711 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSetUpdater
712 (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject updaterFuncObj) {
713 jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj);
714 return MXKVStoreSetUpdater(reinterpret_cast<KVStoreHandle>(kvStorePtr),
715 KVStoreUpdaterCallbackFunc,
716 reinterpret_cast<void *>(updaterFuncObjGlb));
717 }
718
Java_org_apache_mxnet_LibInfo_mxKVStoreIsWorkerNode(JNIEnv * env,jobject obj,jobject isWorkerRef)719 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreIsWorkerNode
720 (JNIEnv *env, jobject obj, jobject isWorkerRef) {
721 int isWorker;
722 int ret = MXKVStoreIsWorkerNode(&isWorker);
723 SetIntField(env, isWorkerRef, isWorker);
724 return ret;
725 }
726
Java_org_apache_mxnet_LibInfo_mxKVStoreCreate(JNIEnv * env,jobject obj,jstring name,jobject kvStoreHandle)727 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreCreate
728 (JNIEnv *env, jobject obj, jstring name, jobject kvStoreHandle) {
729 jclass refLongClass = env->FindClass("org/apache/mxnet/Base$RefLong");
730 jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
731
732 KVStoreHandle out;
733 const char *type = env->GetStringUTFChars(name, 0);
734 int ret = MXKVStoreCreate(type, &out);
735 env->ReleaseStringUTFChars(name, type);
736
737 env->SetLongField(kvStoreHandle, refLongFid, reinterpret_cast<jlong>(out));
738 return ret;
739 }
740
Java_org_apache_mxnet_LibInfo_mxKVStoreInit(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jintArray keys,jlongArray values)741 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreInit
742 (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys, jlongArray values) {
743 jint *keyArray = env->GetIntArrayElements(keys, NULL);
744 jlong *valueArray = env->GetLongArrayElements(values, NULL);
745 int ret = MXKVStoreInit(reinterpret_cast<KVStoreHandle>(kvStorePtr),
746 static_cast<mx_uint>(len),
747 static_cast<const int *>(keyArray),
748 reinterpret_cast<NDArrayHandle *>(valueArray));
749 env->ReleaseIntArrayElements(keys, keyArray, 0);
750 env->ReleaseLongArrayElements(values, valueArray, 0);
751 return ret;
752 }
753
Java_org_apache_mxnet_LibInfo_mxKVStoreInitEx(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jobjectArray keys,jlongArray values)754 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreInitEx
755 (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys, jlongArray values) {
756 const char **keyArray = new const char *[len];
757 for (int i = 0; i < len; i++) {
758 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
759 const char *key = env->GetStringUTFChars(jkey, 0);
760 keyArray[i] = key;
761 env->DeleteLocalRef(jkey);
762 }
763 jlong *valueArray = env->GetLongArrayElements(values, NULL);
764 int ret = MXKVStoreInitEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
765 static_cast<mx_uint>(len),
766 keyArray,
767 reinterpret_cast<NDArrayHandle *>(valueArray));
768 env->ReleaseLongArrayElements(values, valueArray, 0);
769 for (int i = 0; i < len; i++) {
770 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
771 env->ReleaseStringUTFChars(jkey, keyArray[i]);
772 env->DeleteLocalRef(jkey);
773 }
774 delete[] keyArray;
775 return ret;
776 }
777
Java_org_apache_mxnet_LibInfo_mxKVStorePush(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jintArray keys,jlongArray values,jint priority)778 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePush
779 (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
780 jlongArray values, jint priority) {
781 jint *keyArray = env->GetIntArrayElements(keys, NULL);
782 jlong *valueArray = env->GetLongArrayElements(values, NULL);
783 int ret = MXKVStorePush(reinterpret_cast<KVStoreHandle>(kvStorePtr),
784 static_cast<mx_uint>(len),
785 static_cast<const int *>(keyArray),
786 reinterpret_cast<NDArrayHandle *>(valueArray),
787 priority);
788 env->ReleaseLongArrayElements(values, valueArray, 0);
789 return ret;
790 }
791
Java_org_apache_mxnet_LibInfo_mxKVStorePushEx(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jobjectArray keys,jlongArray values,jint priority)792 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePushEx
793 (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
794 jlongArray values, jint priority) {
795 const char **keyArray = new const char *[len];
796 for (int i = 0; i < len; i++) {
797 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
798 const char *key = env->GetStringUTFChars(jkey, 0);
799 keyArray[i] = key;
800 env->DeleteLocalRef(jkey);
801 }
802 jlong *valueArray = env->GetLongArrayElements(values, NULL);
803 int ret = MXKVStorePushEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
804 static_cast<mx_uint>(len),
805 keyArray,
806 reinterpret_cast<NDArrayHandle *>(valueArray),
807 priority);
808 env->ReleaseLongArrayElements(values, valueArray, 0);
809 for (int i = 0; i < len; i++) {
810 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
811 env->ReleaseStringUTFChars(jkey, keyArray[i]);
812 env->DeleteLocalRef(jkey);
813 }
814 delete[] keyArray;
815 return ret;
816 }
817
Java_org_apache_mxnet_LibInfo_mxKVStorePull(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jintArray keys,jlongArray outs,jint priority)818 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePull
819 (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
820 jlongArray outs, jint priority) {
821 jint *keyArray = env->GetIntArrayElements(keys, NULL);
822 jlong *outArray = env->GetLongArrayElements(outs, NULL);
823 int ret = MXKVStorePull(reinterpret_cast<KVStoreHandle>(kvStorePtr),
824 static_cast<mx_uint>(len),
825 static_cast<const int *>(keyArray),
826 reinterpret_cast<NDArrayHandle *>(outArray),
827 priority);
828 env->ReleaseIntArrayElements(keys, keyArray, 0);
829 env->ReleaseLongArrayElements(outs, outArray, 0);
830 return ret;
831 }
832
Java_org_apache_mxnet_LibInfo_mxKVStorePullEx(JNIEnv * env,jobject obj,jlong kvStorePtr,jint len,jobjectArray keys,jlongArray outs,jint priority)833 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStorePullEx
834 (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
835 jlongArray outs, jint priority) {
836 const char **keyArray = new const char *[len];
837 for (int i = 0; i < len; i++) {
838 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
839 const char *key = env->GetStringUTFChars(jkey, 0);
840 keyArray[i] = key;
841 env->DeleteLocalRef(jkey);
842 }
843 jlong *outArray = env->GetLongArrayElements(outs, NULL);
844 int ret = MXKVStorePullEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
845 static_cast<mx_uint>(len),
846 keyArray,
847 reinterpret_cast<NDArrayHandle *>(outArray),
848 priority);
849 env->ReleaseLongArrayElements(outs, outArray, 0);
850 for (int i = 0; i < len; i++) {
851 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
852 env->ReleaseStringUTFChars(jkey, keyArray[i]);
853 env->DeleteLocalRef(jkey);
854 }
855 delete[] keyArray;
856 return ret;
857 }
858
Java_org_apache_mxnet_LibInfo_mxKVStoreGetType(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject kvType)859 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetType
860 (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject kvType) {
861 const char *type;
862 int ret = MXKVStoreGetType(reinterpret_cast<KVStoreHandle>(kvStorePtr), &type);
863 jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
864 jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
865 env->SetObjectField(kvType, valueStr, env->NewStringUTF(type));
866 return ret;
867 }
868
Java_org_apache_mxnet_LibInfo_mxKVStoreSendCommmandToServers(JNIEnv * env,jobject obj,jlong kvStorePtr,jint head,jstring body)869 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSendCommmandToServers
870 (JNIEnv *env, jobject obj, jlong kvStorePtr, jint head, jstring body) {
871 const char *bodyCStr = env->GetStringUTFChars(body, 0);
872 int ret = MXKVStoreSendCommmandToServers(
873 reinterpret_cast<KVStoreHandle>(kvStorePtr), head, bodyCStr);
874 env->ReleaseStringUTFChars(body, bodyCStr);
875 return ret;
876 }
877
Java_org_apache_mxnet_LibInfo_mxKVStoreBarrier(JNIEnv * env,jobject obj,jlong kvStorePtr)878 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreBarrier
879 (JNIEnv *env, jobject obj, jlong kvStorePtr) {
880 return MXKVStoreBarrier((KVStoreHandle)kvStorePtr);
881 }
882
Java_org_apache_mxnet_LibInfo_mxKVStoreGetGroupSize(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject sizeRef)883 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetGroupSize
884 (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject sizeRef) {
885 int size;
886 int ret = MXKVStoreGetGroupSize(reinterpret_cast<KVStoreHandle>(kvStorePtr), &size);
887 SetIntField(env, sizeRef, size);
888 return ret;
889 }
890
Java_org_apache_mxnet_LibInfo_mxKVStoreGetRank(JNIEnv * env,jobject obj,jlong kvStorePtr,jobject rankRef)891 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetRank
892 (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject rankRef) {
893 int rank;
894 int ret = MXKVStoreGetRank(reinterpret_cast<KVStoreHandle>(kvStorePtr), &rank);
895 SetIntField(env, rankRef, rank);
896 return ret;
897 }
898
Java_org_apache_mxnet_LibInfo_mxKVStoreGetNumDeadNode(JNIEnv * env,jobject obj,jlong kvStorePtr,jint nodeId,jobject numberRef)899 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreGetNumDeadNode
900 (JNIEnv * env, jobject obj, jlong kvStorePtr, jint nodeId, jobject numberRef) {
901 int number;
902 int ret = MXKVStoreGetNumDeadNode(reinterpret_cast<KVStoreHandle>(kvStorePtr),
903 static_cast<const int>(nodeId),
904 &number);
905 SetIntField(env, numberRef, number);
906 return ret;
907 }
908
Java_org_apache_mxnet_LibInfo_mxKVStoreSetBarrierBeforeExit(JNIEnv * env,jobject obj,jlong kvStorePtr,jint doBarrier)909 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreSetBarrierBeforeExit
910 (JNIEnv * env, jobject obj, jlong kvStorePtr, jint doBarrier) {
911 return MXKVStoreSetBarrierBeforeExit(reinterpret_cast<KVStoreHandle>(kvStorePtr),
912 static_cast<const int>(doBarrier));
913 }
914
Java_org_apache_mxnet_LibInfo_mxKVStoreFree(JNIEnv * env,jobject obj,jlong ptr)915 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxKVStoreFree
916 (JNIEnv * env, jobject obj, jlong ptr) {
917 return MXKVStoreFree(reinterpret_cast<KVStoreHandle>(ptr));
918 }
919
Java_org_apache_mxnet_LibInfo_mxExecutorOutputs(JNIEnv * env,jobject obj,jlong executorPtr,jobject outputs)920 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorOutputs
921 (JNIEnv *env, jobject obj, jlong executorPtr, jobject outputs) {
922 mx_uint outSize;
923 NDArrayHandle *out;
924 int ret = MXExecutorOutputs(reinterpret_cast<ExecutorHandle>(executorPtr), &outSize, &out);
925
926 jclass longCls = env->FindClass("java/lang/Long");
927 jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
928
929 // fill java outputs
930 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
931 jmethodID arrayAppend = env->GetMethodID(arrayClass,
932 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
933 for (size_t i = 0; i < outSize; ++i) {
934 env->CallObjectMethod(outputs, arrayAppend,
935 env->NewObject(longCls, longConst, out[i]));
936 }
937
938 return ret;
939 }
940
Java_org_apache_mxnet_LibInfo_mxExecutorFree(JNIEnv * env,jobject obj,jlong ptr)941 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorFree
942 (JNIEnv * env, jobject obj, jlong ptr) {
943 return MXExecutorFree(reinterpret_cast<ExecutorHandle>(ptr));
944 }
945
Java_org_apache_mxnet_LibInfo_mxExecutorForward(JNIEnv * env,jobject obj,jlong ptr,jint isTrain)946 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorForward
947 (JNIEnv * env, jobject obj, jlong ptr, jint isTrain) {
948 return MXExecutorForward(reinterpret_cast<ExecutorHandle>(ptr), static_cast<int>(isTrain));
949 }
950
Java_org_apache_mxnet_LibInfo_mxExecutorBackward(JNIEnv * env,jobject obj,jlong executorPtr,jlongArray grads)951 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBackward
952 (JNIEnv * env, jobject obj, jlong executorPtr, jlongArray grads) {
953 int gradsSize = env->GetArrayLength(grads);
954 jlong *gradArr = env->GetLongArrayElements(grads, NULL);
955 int ret = MXExecutorBackward(reinterpret_cast<ExecutorHandle>(executorPtr),
956 static_cast<mx_uint>(gradsSize),
957 reinterpret_cast<NDArrayHandle *>(gradArr));
958 env->ReleaseLongArrayElements(grads, gradArr, 0);
959 return ret;
960 }
961
Java_org_apache_mxnet_LibInfo_mxExecutorReshape(JNIEnv * env,jobject obj,jint partialReshaping,jint allowUpSizing,jint devType,jint devId,jobjectArray jmapKeys,jintArray jmapDevTypes,jintArray jmapDevIds,jobjectArray jprovidedArgShapeNames,jintArray jprovidedArgShapeData,jintArray jprovidedArgShapeIdx,jobject jrefInArgs,jobject jrefArgGrads,jobject jrefAuxStates,jlong jsharedExec,jobject jrefOut)962 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape
963 (JNIEnv * env, jobject obj,
964 jint partialReshaping, jint allowUpSizing, jint devType, jint devId,
965 jobjectArray jmapKeys, jintArray jmapDevTypes, jintArray jmapDevIds,
966 jobjectArray jprovidedArgShapeNames, jintArray jprovidedArgShapeData,
967 jintArray jprovidedArgShapeIdx, jobject jrefInArgs, jobject jrefArgGrads,
968 jobject jrefAuxStates, jlong jsharedExec, jobject jrefOut) {
969 CHECK(jmapKeys != NULL);
970 CHECK(jprovidedArgShapeNames != NULL);
971
972 int numMapKeys = env->GetArrayLength(jmapKeys);
973 jint *mapDevTypes = env->GetIntArrayElements(jmapDevTypes, NULL);
974 jint *mapDevIds = env->GetIntArrayElements(jmapDevIds, NULL);
975 const char **mapKeys = NULL;
976 if (numMapKeys > 0) {
977 mapKeys = new const char*[numMapKeys];
978 for (int i = 0; i < numMapKeys; ++i) {
979 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
980 mapKeys[i] = env->GetStringUTFChars(jkey, 0);
981 env->DeleteLocalRef(jkey);
982 }
983 }
984
985 int numProvidedArgShapes = env->GetArrayLength(jprovidedArgShapeNames);
986 jint *providedArgShapeData = env->GetIntArrayElements(jprovidedArgShapeData, NULL);
987 jint *providedArgShapeIdx = env->GetIntArrayElements(jprovidedArgShapeIdx, NULL);
988 const char **providedArgShapeNames = NULL;
989 if (numProvidedArgShapes > 0) {
990 providedArgShapeNames = new const char*[numProvidedArgShapes];
991 for (int i = 0; i < numProvidedArgShapes; ++i) {
992 jstring jkey = reinterpret_cast<jstring>(
993 env->GetObjectArrayElement(jprovidedArgShapeNames, i));
994 providedArgShapeNames[i] = env->GetStringUTFChars(jkey, 0);
995 env->DeleteLocalRef(jkey);
996 }
997 }
998
999 mx_uint numInArgs = 0;
1000 NDArrayHandle *inArgs;
1001 NDArrayHandle *argGrads;
1002
1003 mx_uint numAuxStates = 0;
1004 NDArrayHandle *auxStates;
1005
1006 ExecutorHandle out;
1007
1008 int ret = MXExecutorReshapeEx(partialReshaping,
1009 allowUpSizing,
1010 devType,
1011 devId,
1012 static_cast<mx_uint>(numMapKeys),
1013 mapKeys,
1014 static_cast<const int*>(mapDevTypes),
1015 static_cast<const int*>(mapDevIds),
1016 static_cast<const mx_uint>(numProvidedArgShapes),
1017 providedArgShapeNames,
1018 static_cast<const int*>(providedArgShapeData),
1019 reinterpret_cast<const mx_uint*>(providedArgShapeIdx),
1020 &numInArgs,
1021 &inArgs,
1022 &argGrads,
1023 &numAuxStates,
1024 &auxStates,
1025 reinterpret_cast<ExecutorHandle>(jsharedExec),
1026 &out);
1027
1028 jclass longCls = env->FindClass("java/lang/Long");
1029 jmethodID newLong = env->GetMethodID(longCls, "<init>", "(J)V");
1030
1031 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1032 jmethodID arrayAppend = env->GetMethodID(arrayClass,
1033 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1034
1035 for (size_t i = 0; i < numInArgs; ++i) {
1036 jobject inArg = env->NewObject(longCls, newLong, inArgs[i]);
1037 env->CallObjectMethod(jrefInArgs, arrayAppend, inArg);
1038 env->DeleteLocalRef(inArg);
1039
1040 jobject argGrad = env->NewObject(longCls, newLong, argGrads[i]);
1041 env->CallObjectMethod(jrefArgGrads, arrayAppend, argGrad);
1042 env->DeleteLocalRef(argGrad);
1043 }
1044
1045 for (size_t i = 0; i < numAuxStates; ++i) {
1046 jobject auxState = env->NewObject(longCls, newLong, auxStates[i]);
1047 env->CallObjectMethod(jrefAuxStates, arrayAppend, auxState);
1048 env->DeleteLocalRef(auxState);
1049 }
1050
1051 SetLongField(env, jrefOut, reinterpret_cast<jlong>(out));
1052
1053 // release allocated memory
1054 for (int i = 0; i < numMapKeys; i++) {
1055 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
1056 env->ReleaseStringUTFChars(jkey, mapKeys[i]);
1057 env->DeleteLocalRef(jkey);
1058 }
1059 if (mapKeys != NULL) {
1060 delete[] mapKeys;
1061 }
1062
1063 for (int i = 0; i < numProvidedArgShapes; i++) {
1064 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jprovidedArgShapeNames, i));
1065 env->ReleaseStringUTFChars(jkey, providedArgShapeNames[i]);
1066 env->DeleteLocalRef(jkey);
1067 }
1068 if (providedArgShapeNames != NULL) {
1069 delete[] providedArgShapeNames;
1070 }
1071
1072 return ret;
1073 }
1074
Java_org_apache_mxnet_LibInfo_mxExecutorPrint(JNIEnv * env,jobject obj,jlong ptr,jobject debugStr)1075 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorPrint
1076 (JNIEnv * env, jobject obj, jlong ptr, jobject debugStr) {
1077 const char *retDebugStr;
1078 int ret = MXExecutorPrint(reinterpret_cast<ExecutorHandle>(ptr), &retDebugStr);
1079 SetStringField(env, debugStr, retDebugStr);
1080 return ret;
1081 }
1082
ExecutorMonitorCallbackFunc(const char * name,NDArrayHandle arr,void * handle)1083 extern "C" void ExecutorMonitorCallbackFunc
1084 (const char *name, NDArrayHandle arr, void *handle) {
1085 jobject callbackFuncObjGlb = static_cast<jobject>(handle);
1086
1087 JNIEnv *env;
1088 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
1089
1090 // find java callback method
1091 jclass callbackClass = env->GetObjectClass(callbackFuncObjGlb);
1092 jmethodID callbackFunc = env->GetMethodID(callbackClass, "invoke", "(Ljava/lang/String;J)V");
1093
1094 // invoke java callback method
1095 jstring jname = env->NewStringUTF(name);
1096 env->CallVoidMethod(callbackFuncObjGlb, callbackFunc, jname, reinterpret_cast<jlong>(arr));
1097 env->DeleteLocalRef(jname);
1098
1099 env->DeleteLocalRef(callbackClass);
1100 // FIXME(Yizhi): This function can be called multiple times,
1101 // can we find a way to safely destroy this global ref ?
1102 // env->DeleteGlobalRef(callbackFuncObjGlb);
1103 }
Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallback(JNIEnv * env,jobject obj,jlong executorPtr,jobject callbackFuncObj)1104 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallback
1105 (JNIEnv *env, jobject obj, jlong executorPtr, jobject callbackFuncObj) {
1106 jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj);
1107 return MXExecutorSetMonitorCallback(reinterpret_cast<ExecutorHandle>(executorPtr),
1108 ExecutorMonitorCallbackFunc,
1109 reinterpret_cast<void *>(callbackFuncObjGlb));
1110 }
1111
Java_org_apache_mxnet_LibInfo_mxGetLastError(JNIEnv * env,jobject obj)1112 JNIEXPORT jstring JNICALL Java_org_apache_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) {
1113 return env->NewStringUTF(MXGetLastError());
1114 }
1115
1116 // IO funcs
Java_org_apache_mxnet_LibInfo_mxListDataIters(JNIEnv * env,jobject obj,jobject creators)1117 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxListDataIters
1118 (JNIEnv * env, jobject obj, jobject creators) {
1119 jclass longCls = env->FindClass("java/lang/Long");
1120 jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
1121
1122 // scala.collection.mutable.ListBuffer append method
1123 jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1124 jmethodID listAppend = env->GetMethodID(listClass,
1125 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1126
1127 // Get function list
1128 DataIterCreator *outArray;
1129 mx_uint outSize;
1130 int ret = MXListDataIters(&outSize, &outArray);
1131 for (size_t i = 0; i < outSize; ++i) {
1132 env->CallObjectMethod(creators, listAppend,
1133 env->NewObject(longCls, longConst,
1134 reinterpret_cast<uint64_t>(outArray[i])));
1135 }
1136 return ret;
1137 }
1138
Java_org_apache_mxnet_LibInfo_mxDataIterCreateIter(JNIEnv * env,jobject obj,jlong creator,jobjectArray jkeys,jobjectArray jvals,jobject dataIterHandleRef)1139 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterCreateIter
1140 (JNIEnv * env, jobject obj, jlong creator, jobjectArray jkeys,
1141 jobjectArray jvals, jobject dataIterHandleRef) {
1142 // keys and values
1143 int paramSize = env->GetArrayLength(jkeys);
1144 const char** keys = new const char*[paramSize];
1145 const char** vals = new const char*[paramSize];
1146 jstring jkey, jval;
1147 // use strcpy and release char* created by JNI inplace
1148 for (int i = 0; i < paramSize; i++) {
1149 jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1150 const char* ckey = env->GetStringUTFChars(jkey, 0);
1151 keys[i] = ckey;
1152 env->DeleteLocalRef(jkey);
1153
1154 jval = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, i));
1155 const char* cval = env->GetStringUTFChars(jval, 0);
1156 vals[i] = cval;
1157 env->DeleteLocalRef(jval);
1158 }
1159
1160 // create iter
1161 DataIterHandle out;
1162 int ret = MXDataIterCreateIter(reinterpret_cast<DataIterCreator>(creator),
1163 static_cast<mx_uint>(paramSize),
1164 static_cast<const char**>(keys),
1165 static_cast<const char**>(vals),
1166 &out);
1167 SetLongField(env, dataIterHandleRef, reinterpret_cast<jlong>(out));
1168
1169 // release keys and vals
1170 for (int i = 0; i < paramSize; i++) {
1171 jstring key = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1172 env->ReleaseStringUTFChars(key, keys[i]);
1173 env->DeleteLocalRef(key);
1174
1175 jstring value = reinterpret_cast<jstring>(env->GetObjectArrayElement(jvals, i));
1176 env->ReleaseStringUTFChars(value, vals[i]);
1177 env->DeleteLocalRef(value);
1178 }
1179 delete[] keys;
1180 delete[] vals;
1181
1182 return ret;
1183 }
1184
Java_org_apache_mxnet_LibInfo_mxDataIterGetIterInfo(JNIEnv * env,jobject obj,jlong creator,jobject jname,jobject jdesc,jobject jargNames,jobject jargTypeInfos,jobject jargDescs)1185 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetIterInfo
1186 (JNIEnv * env, jobject obj, jlong creator, jobject jname,
1187 jobject jdesc, jobject jargNames, jobject jargTypeInfos, jobject jargDescs) {
1188 const char* name;
1189 const char* description;
1190 mx_uint numArgs;
1191 const char** argNames;
1192 const char** argTypeInfos;
1193 const char** argDescs;
1194 int ret = MXDataIterGetIterInfo(reinterpret_cast<DataIterCreator>(creator),
1195 &name,
1196 &description,
1197 &numArgs,
1198 &argNames,
1199 &argTypeInfos,
1200 &argDescs);
1201
1202 jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
1203 jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
1204 // set params
1205 env->SetObjectField(jname, valueStr, env->NewStringUTF(name));
1206 env->SetObjectField(jdesc, valueStr, env->NewStringUTF(description));
1207 jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1208 jmethodID listAppend = env->GetMethodID(listClass,
1209 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1210 for (size_t i = 0; i < numArgs; i++) {
1211 env->CallObjectMethod(jargNames, listAppend, env->NewStringUTF(argNames[i]));
1212 env->CallObjectMethod(jargTypeInfos, listAppend, env->NewStringUTF(argTypeInfos[i]));
1213 env->CallObjectMethod(jargDescs, listAppend, env->NewStringUTF(argDescs[i]));
1214 }
1215 return ret;
1216 }
1217
Java_org_apache_mxnet_LibInfo_mxDataIterFree(JNIEnv * env,jobject obj,jlong handle)1218 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterFree
1219 (JNIEnv *env, jobject obj, jlong handle) {
1220 int ret = MXDataIterFree(reinterpret_cast<DataIterHandle>(handle));
1221 return ret;
1222 }
1223
Java_org_apache_mxnet_LibInfo_mxDataIterBeforeFirst(JNIEnv * env,jobject obj,jlong handle)1224 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterBeforeFirst
1225 (JNIEnv *env, jobject obj, jlong handle) {
1226 int ret = MXDataIterBeforeFirst(reinterpret_cast<DataIterHandle>(handle));
1227 return ret;
1228 }
1229
Java_org_apache_mxnet_LibInfo_mxDataIterNext(JNIEnv * env,jobject obj,jlong handle,jobject out)1230 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterNext
1231 (JNIEnv *env, jobject obj, jlong handle, jobject out) {
1232 int cout;
1233 int ret = MXDataIterNext(reinterpret_cast<DataIterHandle>(handle), &cout);
1234 SetIntField(env, out, cout);
1235 return ret;
1236 }
1237
Java_org_apache_mxnet_LibInfo_mxDataIterGetLabel(JNIEnv * env,jobject obj,jlong handle,jobject ndArrayHandleRef)1238 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetLabel
1239 (JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) {
1240 NDArrayHandle out;
1241 int ret = MXDataIterGetLabel(reinterpret_cast<DataIterHandle>(handle), &out);
1242 SetLongField(env, ndArrayHandleRef, reinterpret_cast<jlong>(out));
1243 return ret;
1244 }
1245
Java_org_apache_mxnet_LibInfo_mxDataIterGetData(JNIEnv * env,jobject obj,jlong handle,jobject ndArrayHandleRef)1246 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetData
1247 (JNIEnv *env, jobject obj, jlong handle, jobject ndArrayHandleRef) {
1248 NDArrayHandle out;
1249 int ret = MXDataIterGetData(reinterpret_cast<DataIterHandle>(handle), &out);
1250 SetLongField(env, ndArrayHandleRef, reinterpret_cast<jlong>(out));
1251 return ret;
1252 }
1253
Java_org_apache_mxnet_LibInfo_mxDataIterGetIndex(JNIEnv * env,jobject obj,jlong handle,jobject outIndex,jobject outSize)1254 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetIndex
1255 (JNIEnv *env, jobject obj, jlong handle, jobject outIndex, jobject outSize) {
1256 uint64_t* coutIndex;
1257 uint64_t coutSize;
1258 int ret = MXDataIterGetIndex(reinterpret_cast<DataIterHandle>(handle), &coutIndex, &coutSize);
1259 // set field
1260 SetLongField(env, outSize, static_cast<jlong>(coutSize));
1261 // scala.collection.mutable.ListBuffer append method
1262 jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1263 jmethodID listAppend = env->GetMethodID(listClass,
1264 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1265
1266 // long class
1267 jclass longCls = env->FindClass("java/lang/Long");
1268 jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
1269
1270 for (size_t i = 0; i < coutSize; i++) {
1271 env->CallObjectMethod(outIndex, listAppend,
1272 env->NewObject(longCls, longConst, coutIndex[i]));
1273 }
1274 return ret;
1275 }
1276
Java_org_apache_mxnet_LibInfo_mxDataIterGetPadNum(JNIEnv * env,jobject obj,jlong handle,jobject pad)1277 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDataIterGetPadNum
1278 (JNIEnv *env, jobject obj, jlong handle, jobject pad) {
1279 int cpad;
1280 int ret = MXDataIterGetPadNum((DataIterHandle)handle, &cpad);
1281 SetIntField(env, pad, cpad);
1282 return ret;
1283 }
1284
1285 // Symbol functions
Java_org_apache_mxnet_LibInfo_mxSymbolFree(JNIEnv * env,jobject obj,jlong ptr)1286 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolFree
1287 (JNIEnv * env, jobject obj, jlong ptr) {
1288 return MXSymbolFree((SymbolHandle) ptr);
1289 }
1290
Java_org_apache_mxnet_LibInfo_mxSymbolListAtomicSymbolCreators(JNIEnv * env,jobject obj,jobject symbolList)1291 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAtomicSymbolCreators
1292 (JNIEnv *env, jobject obj, jobject symbolList) {
1293 mx_uint outSize;
1294 AtomicSymbolCreator *outArray;
1295 int ret = MXSymbolListAtomicSymbolCreators(&outSize, &outArray);
1296
1297 jclass longCls = env->FindClass("java/lang/Long");
1298 jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
1299
1300 jclass listCls = env->FindClass("scala/collection/mutable/ListBuffer");
1301 jmethodID listAppend = env->GetMethodID(listCls,
1302 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1303
1304 for (size_t i = 0; i < outSize; ++i) {
1305 env->CallObjectMethod(symbolList, listAppend,
1306 env->NewObject(longCls, longConst, outArray[i]));
1307 }
1308
1309 return ret;
1310 }
1311
Java_org_apache_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo(JNIEnv * env,jobject obj,jlong symbolPtr,jobject name,jobject desc,jobject numArgs,jobject argNames,jobject argTypes,jobject argDescs,jobject keyVarNumArgs)1312 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo
1313 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject name, jobject desc, jobject numArgs,
1314 jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs) {
1315
1316 const char *cName;
1317 const char *cDesc;
1318 mx_uint cNumArgs;
1319 const char **cArgNames;
1320 const char **cArgTypes;
1321 const char **cArgDescs;
1322 const char *cKeyVarNumArgs;
1323
1324 int ret = MXSymbolGetAtomicSymbolInfo(reinterpret_cast<AtomicSymbolCreator>(symbolPtr),
1325 &cName, &cDesc, &cNumArgs,
1326 &cArgNames, &cArgTypes, &cArgDescs,
1327 &cKeyVarNumArgs);
1328
1329 jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
1330 jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
1331
1332 jclass refStringClass = env->FindClass("org/apache/mxnet/Base$RefString");
1333 jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
1334
1335 // scala.collection.mutable.ListBuffer append method
1336 jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1337 jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
1338 "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1339
1340 env->SetObjectField(name, valueStr, env->NewStringUTF(cName));
1341 env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc));
1342 env->SetObjectField(keyVarNumArgs, valueStr, env->NewStringUTF(cKeyVarNumArgs));
1343 env->SetIntField(numArgs, valueInt, static_cast<jint>(cNumArgs));
1344 for (size_t i = 0; i < cNumArgs; ++i) {
1345 env->CallObjectMethod(argNames, listAppend, env->NewStringUTF(cArgNames[i]));
1346 env->CallObjectMethod(argTypes, listAppend, env->NewStringUTF(cArgTypes[i]));
1347 env->CallObjectMethod(argDescs, listAppend, env->NewStringUTF(cArgDescs[i]));
1348 }
1349
1350 return ret;
1351 }
1352
Java_org_apache_mxnet_LibInfo_mxSymbolCreateAtomicSymbol(JNIEnv * env,jobject obj,jlong symbolPtr,jobjectArray paramKeys,jobjectArray paramVals,jobject symbolRef)1353 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateAtomicSymbol
1354 (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray paramKeys,
1355 jobjectArray paramVals, jobject symbolRef) {
1356 int paramSize = env->GetArrayLength(paramKeys);
1357 const char **keys = new const char*[paramSize];
1358 const char **vals = new const char*[paramSize];
1359 for (int i = 0; i < paramSize; i++) {
1360 jstring key = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
1361 const char *rawKey = env->GetStringUTFChars(key, 0);
1362 keys[i] = rawKey;
1363 env->DeleteLocalRef(key);
1364
1365 jstring value = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
1366 const char *rawValue = env->GetStringUTFChars(value, 0);
1367 vals[i] = rawValue;
1368 env->DeleteLocalRef(value);
1369 }
1370
1371 SymbolHandle out;
1372 int ret = MXSymbolCreateAtomicSymbol(reinterpret_cast<AtomicSymbolCreator>(symbolPtr),
1373 static_cast<mx_uint>(paramSize), keys, vals, &out);
1374 SetLongField(env, symbolRef, reinterpret_cast<jlong>(out));
1375
1376 // release keys and vals
1377 for (int i = 0; i < paramSize; i++) {
1378 jstring key = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramKeys, i));
1379 env->ReleaseStringUTFChars(key, keys[i]);
1380 env->DeleteLocalRef(key);
1381
1382 jstring value = reinterpret_cast<jstring>(env->GetObjectArrayElement(paramVals, i));
1383 env->ReleaseStringUTFChars(value, vals[i]);
1384 env->DeleteLocalRef(value);
1385 }
1386 delete[] keys;
1387 delete[] vals;
1388
1389 return ret;
1390 }
1391
Java_org_apache_mxnet_LibInfo_mxSymbolSetAttr(JNIEnv * env,jobject obj,jlong symbolPtr,jstring jkey,jstring jvalue)1392 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSetAttr
1393 (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jstring jvalue) {
1394 const char *ckey = env->GetStringUTFChars(jkey, 0);
1395 const char *cvalue = env->GetStringUTFChars(jvalue, 0);
1396 int ret = MXSymbolSetAttr(reinterpret_cast<SymbolHandle>(symbolPtr), ckey, cvalue);
1397 env->ReleaseStringUTFChars(jkey, ckey);
1398 env->ReleaseStringUTFChars(jvalue, cvalue);
1399 return ret;
1400 }
1401
Java_org_apache_mxnet_LibInfo_mxSymbolListAttrShallow(JNIEnv * env,jobject obj,jlong symbolPtr,jobject joutSize,jobject jout)1402 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAttrShallow
1403 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject joutSize, jobject jout) {
1404 mx_uint outSize;
1405 const char** out;
1406
1407 int ret = MXSymbolListAttrShallow(reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &out);
1408
1409 jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
1410 jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
1411 env->SetIntField(joutSize, valueInt, static_cast<jint>(outSize));
1412
1413 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1414 jmethodID arrayAppend = env->GetMethodID(arrayClass,
1415 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1416 for (size_t i = 0; i < outSize * 2; ++i) {
1417 jstring jtmp = env->NewStringUTF(out[i]);
1418 env->CallObjectMethod(jout, arrayAppend, jtmp);
1419 env->DeleteLocalRef(jtmp);
1420 }
1421
1422 return ret;
1423 }
1424
Java_org_apache_mxnet_LibInfo_mxSymbolListAttr(JNIEnv * env,jobject obj,jlong symbolPtr,jobject joutSize,jobject jout)1425 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAttr
1426 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject joutSize, jobject jout) {
1427 mx_uint outSize;
1428 const char** out;
1429
1430 int ret = MXSymbolListAttr(reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &out);
1431
1432 jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt");
1433 jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I");
1434 env->SetIntField(joutSize, valueInt, static_cast<jint>(outSize));
1435
1436 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1437 jmethodID arrayAppend = env->GetMethodID(arrayClass,
1438 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1439 for (size_t i = 0; i < outSize * 2; ++i) {
1440 jstring jtmp = env->NewStringUTF(out[i]);
1441 env->CallObjectMethod(jout, arrayAppend, jtmp);
1442 env->DeleteLocalRef(jtmp);
1443 }
1444
1445 return ret;
1446 }
1447
Java_org_apache_mxnet_LibInfo_mxSymbolCompose(JNIEnv * env,jobject obj,jlong symbolPtr,jstring jname,jobjectArray jkeys,jlongArray jargs)1448 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCompose
1449 (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jname,
1450 jobjectArray jkeys, jlongArray jargs) {
1451 int argSize = env->GetArrayLength(jargs);
1452 const char **keys = NULL;
1453 if (jkeys != NULL) {
1454 keys = new const char*[argSize];
1455 for (int i = 0; i < argSize; i++) {
1456 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1457 const char *key = env->GetStringUTFChars(jkey, 0);
1458 keys[i] = key;
1459 env->DeleteLocalRef(jkey);
1460 }
1461 }
1462 jlong *args = env->GetLongArrayElements(jargs, NULL);
1463 const char *name = env->GetStringUTFChars(jname, 0);
1464 int ret = MXSymbolCompose(reinterpret_cast<SymbolHandle>(symbolPtr),
1465 name, static_cast<mx_uint>(argSize), keys,
1466 reinterpret_cast<SymbolHandle *>(args));
1467 env->ReleaseStringUTFChars(jname, name);
1468 env->ReleaseLongArrayElements(jargs, args, 0);
1469 // release allocated memory
1470 if (jkeys != NULL) {
1471 for (int i = 0; i < argSize; i++) {
1472 jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
1473 env->ReleaseStringUTFChars(jkey, keys[i]);
1474 env->DeleteLocalRef(jkey);
1475 }
1476 delete[] keys;
1477 }
1478 return ret;
1479 }
1480
Java_org_apache_mxnet_LibInfo_mxSymbolCreateVariable(JNIEnv * env,jobject obj,jstring jname,jobject handle)1481 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateVariable
1482 (JNIEnv *env, jobject obj, jstring jname, jobject handle) {
1483 SymbolHandle out;
1484 const char *name = env->GetStringUTFChars(jname, 0);
1485 int ret = MXSymbolCreateVariable(name, &out);
1486 env->ReleaseStringUTFChars(jname, name);
1487 SetLongField(env, handle, reinterpret_cast<jlong>(out));
1488 return ret;
1489 }
1490
Java_org_apache_mxnet_LibInfo_mxSymbolGetAttr(JNIEnv * env,jobject obj,jlong symbolPtr,jstring jkey,jobject retRef,jobject successRef)1491 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAttr
1492 (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jobject retRef, jobject successRef) {
1493 const char *out;
1494 int success;
1495 const char *key = env->GetStringUTFChars(jkey, 0);
1496 int ret = MXSymbolGetAttr(reinterpret_cast<SymbolHandle>(symbolPtr), key, &out, &success);
1497 env->ReleaseStringUTFChars(jkey, key);
1498
1499 SetStringField(env, retRef, out);
1500 SetIntField(env, successRef, success);
1501 return ret;
1502 }
1503
Java_org_apache_mxnet_LibInfo_mxSymbolListArguments(JNIEnv * env,jobject obj,jlong symbolPtr,jobject arguments)1504 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListArguments
1505 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject arguments) {
1506 mx_uint outSize;
1507 const char **outStrArray;
1508 int ret = MXSymbolListArguments(
1509 reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &outStrArray);
1510
1511 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1512 jmethodID arrayAppend = env->GetMethodID(arrayClass,
1513 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1514 for (size_t i = 0; i < outSize; i++) {
1515 jstring argument = env->NewStringUTF(outStrArray[i]);
1516 env->CallObjectMethod(arguments, arrayAppend, argument);
1517 env->DeleteLocalRef(argument);
1518 }
1519
1520 return ret;
1521 }
1522
Java_org_apache_mxnet_LibInfo_mxSymbolListOutputs(JNIEnv * env,jobject obj,jlong symbolPtr,jobject outputs)1523 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListOutputs
1524 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
1525 mx_uint outSize;
1526 const char **outStrArray;
1527 int ret = MXSymbolListOutputs(reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &outStrArray);
1528
1529 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1530 jmethodID arrayAppend = env->GetMethodID(arrayClass,
1531 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1532 for (size_t i = 0; i < outSize; i++) {
1533 jstring output = env->NewStringUTF(outStrArray[i]);
1534 env->CallObjectMethod(outputs, arrayAppend, output);
1535 env->DeleteLocalRef(output);
1536 }
1537
1538 return ret;
1539 }
1540
Java_org_apache_mxnet_LibInfo_mxSymbolListAuxiliaryStates(JNIEnv * env,jobject obj,jlong symbolPtr,jobject outputs)1541 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAuxiliaryStates
1542 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
1543 mx_uint outSize;
1544 const char **outStrArray;
1545 int ret = MXSymbolListAuxiliaryStates(
1546 reinterpret_cast<SymbolHandle>(symbolPtr), &outSize, &outStrArray);
1547
1548 jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
1549 jmethodID arrayAppend = env->GetMethodID(arrayClass,
1550 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
1551 for (size_t i = 0; i < outSize; i++) {
1552 jstring output = env->NewStringUTF(outStrArray[i]);
1553 env->CallObjectMethod(outputs, arrayAppend, output);
1554 env->DeleteLocalRef(output);
1555 }
1556
1557 return ret;
1558 }
1559
Java_org_apache_mxnet_LibInfo_mxSymbolCopy(JNIEnv * env,jobject obj,jlong symbolPtr,jobject clonedSymbolRef)1560 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCopy
1561 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) {
1562 SymbolHandle clonedSymbol;
1563 int ret = MXSymbolCopy(reinterpret_cast<SymbolHandle>(symbolPtr), &clonedSymbol);
1564 SetLongField(env, clonedSymbolRef, reinterpret_cast<jlong>(clonedSymbol));
1565 return ret;
1566 }
1567
Java_org_apache_mxnet_LibInfo_mxSymbolCreateGroup(JNIEnv * env,jobject obj,jlongArray jsymbols,jobject out)1568 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateGroup
1569 (JNIEnv *env, jobject obj, jlongArray jsymbols, jobject out) {
1570 int numSymbols = env->GetArrayLength(jsymbols);
1571 SymbolHandle handle;
1572 jlong *symbols = env->GetLongArrayElements(jsymbols, NULL);
1573 int ret = MXSymbolCreateGroup(numSymbols, reinterpret_cast<SymbolHandle *>(symbols), &handle);
1574 env->ReleaseLongArrayElements(jsymbols, symbols, 0);
1575 SetLongField(env, out, reinterpret_cast<jlong>(handle));
1576 return ret;
1577 }
1578
Java_org_apache_mxnet_LibInfo_mxSymbolPrint(JNIEnv * env,jobject obj,jlong symbolPtr,jobject out)1579 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolPrint
1580 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject out) {
1581 const char *outStr;
1582 int ret = MXSymbolPrint(reinterpret_cast<SymbolHandle>(symbolPtr), &outStr);
1583 SetStringField(env, out, outStr);
1584 return ret;
1585 }
1586
Java_org_apache_mxnet_LibInfo_mxSymbolGetOutput(JNIEnv * env,jobject obj,jlong symbolPtr,jint index,jobject jout)1587 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetOutput
1588 (JNIEnv *env, jobject obj, jlong symbolPtr, jint index, jobject jout) {
1589 SymbolHandle out;
1590 int ret = MXSymbolGetOutput(reinterpret_cast<SymbolHandle>(symbolPtr),
1591 static_cast<mx_uint>(index), &out);
1592 SetLongField(env, jout, reinterpret_cast<jlong>(out));
1593 return ret;
1594 }
1595
Java_org_apache_mxnet_LibInfo_mxSymbolGetInternals(JNIEnv * env,jobject obj,jlong symbolPtr,jobject jout)1596 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetInternals
1597 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject jout) {
1598 SymbolHandle out;
1599 int ret = MXSymbolGetInternals(reinterpret_cast<SymbolHandle>(symbolPtr), &out);
1600 SetLongField(env, jout, reinterpret_cast<jlong>(out));
1601 return ret;
1602 }
1603
Java_org_apache_mxnet_LibInfo_mxSymbolInferType(JNIEnv * env,jobject obj,jlong symbolPtr,jobjectArray jkeys,jintArray jvals,jobject jargTypeData,jobject joutTypeData,jobject jauxTypeData,jobject jcomplete)1604 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferType
1605 (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray jkeys, jintArray jvals,
1606 jobject jargTypeData, jobject joutTypeData, jobject jauxTypeData, jobject jcomplete) {
1607 int numArgs = env->GetArrayLength(jvals);
1608 const char **keys = NULL;
1609 if (jkeys != NULL) {
1610 keys = new const char *[numArgs];
1611 for (int i = 0; i < numArgs; i++) {
1612 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1613 const char *key = env->GetStringUTFChars(jkey, 0);
1614 keys[i] = key;
1615 env->DeleteLocalRef(jkey);
1616 }
1617 }
1618
1619 mx_uint inTypeSize;
1620 const int *inTypeData;
1621 mx_uint outTypeSize;
1622 const int *outTypeData;
1623 mx_uint auxTypeSize;
1624 const int *auxTypeData;
1625 int complete;
1626
1627 jint *vals = env->GetIntArrayElements(jvals, NULL);
1628 int ret = MXSymbolInferType(reinterpret_cast<SymbolHandle>(symbolPtr),
1629 static_cast<mx_uint>(numArgs), keys,
1630 static_cast<const int *>(vals),
1631 &inTypeSize, &inTypeData,
1632 &outTypeSize, &outTypeData,
1633 &auxTypeSize, &auxTypeData,
1634 &complete);
1635 env->ReleaseIntArrayElements(jvals, vals, 0);
1636
1637 jclass integerClass = env->FindClass("java/lang/Integer");
1638 jmethodID newInteger = env->GetMethodID(integerClass, "<init>", "(I)V");
1639
1640 jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1641 jmethodID listAppend = env->GetMethodID(listClass,
1642 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1643
1644 for (size_t i = 0; i < inTypeSize; ++i) {
1645 jobject data = env->NewObject(integerClass, newInteger, inTypeData[i]);
1646 env->CallObjectMethod(jargTypeData, listAppend, data);
1647 env->DeleteLocalRef(data);
1648 }
1649 for (size_t i = 0; i < outTypeSize; ++i) {
1650 jobject data = env->NewObject(integerClass, newInteger, outTypeData[i]);
1651 env->CallObjectMethod(joutTypeData, listAppend, data);
1652 env->DeleteLocalRef(data);
1653 }
1654 for (size_t i = 0; i < auxTypeSize; ++i) {
1655 jobject data = env->NewObject(integerClass, newInteger, auxTypeData[i]);
1656 env->CallObjectMethod(jauxTypeData, listAppend, data);
1657 env->DeleteLocalRef(data);
1658 }
1659
1660 SetIntField(env, jcomplete, complete);
1661
1662 // release allocated memory
1663 if (jkeys != NULL) {
1664 for (int i = 0; i < numArgs; i++) {
1665 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1666 env->ReleaseStringUTFChars(jkey, keys[i]);
1667 env->DeleteLocalRef(jkey);
1668 }
1669 delete[] keys;
1670 }
1671
1672 return ret;
1673 }
1674
Java_org_apache_mxnet_LibInfo_mxSymbolSaveToJSON(JNIEnv * env,jobject obj,jlong symbolPtr,jobject jout)1675 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSaveToJSON
1676 (JNIEnv *env, jobject obj, jlong symbolPtr, jobject jout) {
1677 const char *out;
1678 int ret = MXSymbolSaveToJSON(reinterpret_cast<SymbolHandle>(symbolPtr), &out);
1679 SetStringField(env, jout, out);
1680 return ret;
1681 }
1682
Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromJSON(JNIEnv * env,jobject obj,jstring json,jobject jhandleRef)1683 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromJSON
1684 (JNIEnv *env, jobject obj, jstring json, jobject jhandleRef) {
1685 const char *str = env->GetStringUTFChars(json, 0);
1686 SymbolHandle out;
1687 int ret = MXSymbolCreateFromJSON(str, &out);
1688 SetLongField(env, jhandleRef, reinterpret_cast<jlong>(out));
1689 env->ReleaseStringUTFChars(json, str);
1690 return ret;
1691 }
1692
Java_org_apache_mxnet_LibInfo_mxSymbolSaveToFile(JNIEnv * env,jobject obj,jlong symbolPtr,jstring jfname)1693 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolSaveToFile
1694 (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jfname) {
1695 const char *fname = env->GetStringUTFChars(jfname, 0);
1696 int ret = MXSymbolSaveToFile(reinterpret_cast<SymbolHandle>(symbolPtr), fname);
1697 env->ReleaseStringUTFChars(jfname, fname);
1698 return ret;
1699 }
1700
Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile(JNIEnv * env,jobject obj,jstring jfname,jobject jhandleRef)1701 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile
1702 (JNIEnv *env, jobject obj, jstring jfname, jobject jhandleRef) {
1703 const char *fname = env->GetStringUTFChars(jfname, 0);
1704 SymbolHandle out;
1705 int ret = MXSymbolCreateFromFile(fname, &out);
1706 SetLongField(env, jhandleRef, reinterpret_cast<jlong>(out));
1707 env->ReleaseStringUTFChars(jfname, fname);
1708 return ret;
1709 }
1710
FillSymbolInferShape(JNIEnv * env,jmethodID listAppend,jobject joutData,int shapeSize,const int * shapeNdim,const int ** shapeData)1711 int FillSymbolInferShape
1712 (JNIEnv *env, jmethodID listAppend, jobject joutData,
1713 int shapeSize, const int *shapeNdim, const int **shapeData) {
1714 for (int i = 0; i < shapeSize; ++i) {
1715 jintArray jshape = NULL;
1716 if (shapeNdim[i] >= 0) {
1717 jshape = env->NewIntArray(shapeNdim[i]);
1718 if (jshape == NULL) {
1719 // TODO(Yizhi): out of memory error thrown, return a specific error code ?
1720 return -1;
1721 }
1722 env->SetIntArrayRegion(jshape, 0, shapeNdim[i], reinterpret_cast<const jint *>(shapeData[i]));
1723 }
1724 env->CallObjectMethod(joutData, listAppend, jshape);
1725 env->DeleteLocalRef(jshape);
1726 }
1727 return 0;
1728 }
1729
SymbolInferShapeHelper(JNIEnv * env,jobject obj,jlong symbolPtr,jint jnumArgs,jobjectArray jkeys,jintArray jargIndPtr,jintArray jargShapeData,jobject jinShapeData,jobject joutShapeData,jobject jauxShapeData,jobject jcomplete,bool partial)1730 int SymbolInferShapeHelper(JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs,
1731 jobjectArray jkeys, jintArray jargIndPtr, jintArray jargShapeData,
1732 jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData,
1733 jobject jcomplete, bool partial) {
1734 const char **keys = NULL;
1735 if (jkeys != NULL) {
1736 keys = new const char *[jnumArgs];
1737 for (int i = 0; i < jnumArgs; i++) {
1738 jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
1739 const char *key = env->GetStringUTFChars(jkey, 0);
1740 keys[i] = key;
1741 env->DeleteLocalRef(jkey);
1742 }
1743 }
1744
1745 mx_uint inShapeSize;
1746 const int *inShapeNdim;
1747 const int **inShapeData;
1748
1749 mx_uint outShapeSize;
1750 const int *outShapeNdim;
1751 const int **outShapeData;
1752
1753 mx_uint auxShapeSize;
1754 const int *auxShapeNdim;
1755 const int **auxShapeData;
1756
1757 int complete;
1758
1759 jint *argIndPtr = env->GetIntArrayElements(jargIndPtr, NULL);
1760 jint *argShapeData = env->GetIntArrayElements(jargShapeData, NULL);
1761 int ret;
1762 if (!partial) {
1763 ret = MXSymbolInferShapeEx(reinterpret_cast<SymbolHandle>(symbolPtr),
1764 static_cast<mx_uint>(jnumArgs),
1765 keys,
1766 reinterpret_cast<mx_uint *>(argIndPtr),
1767 reinterpret_cast<const int *>(argShapeData),
1768 &inShapeSize,
1769 &inShapeNdim,
1770 &inShapeData,
1771 &outShapeSize,
1772 &outShapeNdim,
1773 &outShapeData,
1774 &auxShapeSize,
1775 &auxShapeNdim,
1776 &auxShapeData,
1777 &complete);
1778 } else {
1779 ret = MXSymbolInferShapePartialEx(reinterpret_cast<SymbolHandle>(symbolPtr),
1780 static_cast<mx_uint>(jnumArgs),
1781 keys,
1782 reinterpret_cast<mx_uint *>(argIndPtr),
1783 reinterpret_cast<const int *>(argShapeData),
1784 &inShapeSize,
1785 &inShapeNdim,
1786 &inShapeData,
1787 &outShapeSize,
1788 &outShapeNdim,
1789 &outShapeData,
1790 &auxShapeSize,
1791 &auxShapeNdim,
1792 &auxShapeData,
1793 &complete);
1794 }
1795 env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0);
1796 env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0);
1797
1798 if (ret == 0) {
1799 jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
1800 jmethodID listAppend = env->GetMethodID(listClass,
1801 "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
1802
1803 if (FillSymbolInferShape(
1804 env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData)) {
1805 // TODO(Yizhi): out of memory error thrown, return a specific error code ?
1806 return -1;
1807 }
1808 if (FillSymbolInferShape(
1809 env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData)) {
1810 // TODO(Yizhi): out of memory error thrown, return a specific error code ?
1811 return -1;
1812 }
1813 if (FillSymbolInferShape(
1814 env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) {
1815 // TODO(Yizhi): out of memory error thrown, return a specific error code ?
1816 return -1;
1817 }
1818
1819 SetIntField(env, jcomplete, complete);
1820 }
1821
1822 // release allocated memory
1823 if (jkeys != NULL) {
1824 for (int i = 0; i < jnumArgs; i++) {
1825 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jkeys, i));
1826 env->ReleaseStringUTFChars(jkey, keys[i]);
1827 env->DeleteLocalRef(jkey);
1828 }
1829 delete[] keys;
1830 }
1831
1832 return ret;
1833 }
1834
Java_org_apache_mxnet_LibInfo_mxSymbolInferShape(JNIEnv * env,jobject obj,jlong symbolPtr,jint jnumArgs,jobjectArray jkeys,jintArray jargIndPtr,jintArray jargShapeData,jobject jinShapeData,jobject joutShapeData,jobject jauxShapeData,jobject jcomplete)1835 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
1836 (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
1837 jintArray jargIndPtr, jintArray jargShapeData,
1838 jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
1839
1840 return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData,
1841 jinShapeData, joutShapeData, jauxShapeData, jcomplete, false);
1842 }
1843
Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial(JNIEnv * env,jobject obj,jlong symbolPtr,jint jnumArgs,jobjectArray jkeys,jintArray jargIndPtr,jintArray jargShapeData,jobject jinShapeData,jobject joutShapeData,jobject jauxShapeData,jobject jcomplete)1844 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial
1845 (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
1846 jintArray jargIndPtr, jintArray jargShapeData,
1847 jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
1848
1849 return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData,
1850 jinShapeData, joutShapeData, jauxShapeData, jcomplete, true);
1851 }
1852
Java_org_apache_mxnet_LibInfo_mxExecutorBindX(JNIEnv * env,jobject obj,jlong symbolPtr,jint deviceTypeId,jint deviceID,jint numCtx,jobjectArray jctxMapKeys,jintArray jctxMapDevTypes,jintArray jctxMapDevIDs,jint numArgs,jlongArray jargsHandle,jlongArray jargsGradHandle,jintArray jreqsArray,jlongArray jauxArgsHandle,jobject jexecOut)1853 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindX
1854 (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
1855 jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
1856 jlongArray jargsHandle, jlongArray jargsGradHandle, jintArray jreqsArray,
1857 jlongArray jauxArgsHandle, jobject jexecOut) {
1858 ExecutorHandle out;
1859 int auxStatesLen = env->GetArrayLength(jauxArgsHandle);
1860
1861 const char **mapKeys = new const char *[numCtx];
1862 for (int i = 0; i < numCtx; i++) {
1863 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jctxMapKeys, i));
1864 const char *key = env->GetStringUTFChars(jkey, 0);
1865 mapKeys[i] = key;
1866 env->DeleteLocalRef(jkey);
1867 }
1868 jlong *auxStates = env->GetLongArrayElements(jauxArgsHandle, NULL);
1869 jint *gradReqType = env->GetIntArrayElements(jreqsArray, NULL);
1870 jlong *inArgs = env->GetLongArrayElements(jargsHandle, NULL);
1871 jlong *argGradStore = env->GetLongArrayElements(jargsGradHandle, NULL);
1872 jint *mapDevTypes = env->GetIntArrayElements(jctxMapDevTypes, NULL);
1873 jint *mapDevIDs = env->GetIntArrayElements(jctxMapDevIDs, NULL);
1874 int ret = MXExecutorBindX(reinterpret_cast<SymbolHandle>(symbolPtr),
1875 deviceTypeId,
1876 deviceID,
1877 static_cast<mx_uint>(numCtx),
1878 mapKeys,
1879 mapDevTypes,
1880 mapDevIDs,
1881 static_cast<mx_uint>(numArgs),
1882 reinterpret_cast<NDArrayHandle *>(inArgs),
1883 reinterpret_cast<NDArrayHandle *>(argGradStore),
1884 reinterpret_cast<mx_uint *>(gradReqType),
1885 static_cast<mx_uint>(auxStatesLen),
1886 reinterpret_cast<NDArrayHandle *>(auxStates),
1887 &out);
1888 env->ReleaseIntArrayElements(jctxMapDevIDs, mapDevIDs, 0);
1889 env->ReleaseIntArrayElements(jctxMapDevTypes, mapDevTypes, 0);
1890 env->ReleaseLongArrayElements(jargsGradHandle, argGradStore, 0);
1891 env->ReleaseLongArrayElements(jargsHandle, inArgs, 0);
1892 env->ReleaseIntArrayElements(jreqsArray, gradReqType, 0);
1893 env->ReleaseLongArrayElements(jauxArgsHandle, auxStates, 0);
1894 for (int i = 0; i < numCtx; i++) {
1895 jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i);
1896 env->ReleaseStringUTFChars(jkey, mapKeys[i]);
1897 env->DeleteLocalRef(jkey);
1898 }
1899 delete[] mapKeys;
1900
1901 SetLongField(env, jexecOut, reinterpret_cast<jlong>(out));
1902 return ret;
1903 }
1904
Java_org_apache_mxnet_LibInfo_mxExecutorBindEX(JNIEnv * env,jobject obj,jlong symbolPtr,jint deviceTypeId,jint deviceID,jint numCtx,jobjectArray jctxMapKeys,jintArray jctxMapDevTypes,jintArray jctxMapDevIDs,jint numArgs,jlongArray jargsHandle,jlongArray jargsGradHandle,jintArray jreqsArray,jlongArray jauxArgsHandle,jlong jsharedExec,jobject jexecOut)1905 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindEX
1906 (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
1907 jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
1908 jlongArray jargsHandle, jlongArray jargsGradHandle, jintArray jreqsArray,
1909 jlongArray jauxArgsHandle, jlong jsharedExec, jobject jexecOut) {
1910 ExecutorHandle out;
1911 int auxStatesLen = env->GetArrayLength(jauxArgsHandle);
1912 ExecutorHandle sharedExec = nullptr;
1913 if ((int32_t)jsharedExec != 0) sharedExec = reinterpret_cast<ExecutorHandle>(jsharedExec);
1914
1915 const char **mapKeys = new const char *[numCtx];
1916 for (int i = 0; i < numCtx; i++) {
1917 jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jctxMapKeys, i));
1918 const char *key = env->GetStringUTFChars(jkey, 0);
1919 mapKeys[i] = key;
1920 env->DeleteLocalRef(jkey);
1921 }
1922 jlong *auxStates = env->GetLongArrayElements(jauxArgsHandle, NULL);
1923 jint *gradReqType = env->GetIntArrayElements(jreqsArray, NULL);
1924 jlong *inArgs = env->GetLongArrayElements(jargsHandle, NULL);
1925 jlong *argGradStore = env->GetLongArrayElements(jargsGradHandle, NULL);
1926 jint *mapDevTypes = env->GetIntArrayElements(jctxMapDevTypes, NULL);
1927 jint *mapDevIDs = env->GetIntArrayElements(jctxMapDevIDs, NULL);
1928 int ret = MXExecutorBindEX(reinterpret_cast<SymbolHandle>(symbolPtr),
1929 deviceTypeId,
1930 deviceID,
1931 static_cast<mx_uint>(numCtx),
1932 mapKeys,
1933 mapDevTypes,
1934 mapDevIDs,
1935 static_cast<mx_uint>(numArgs),
1936 reinterpret_cast<NDArrayHandle *>(inArgs),
1937 reinterpret_cast<NDArrayHandle *>(argGradStore),
1938 reinterpret_cast<mx_uint *>(gradReqType),
1939 static_cast<mx_uint>(auxStatesLen),
1940 reinterpret_cast<NDArrayHandle *>(auxStates),
1941 sharedExec,
1942 &out);
1943 env->ReleaseIntArrayElements(jctxMapDevIDs, mapDevIDs, 0);
1944 env->ReleaseIntArrayElements(jctxMapDevTypes, mapDevTypes, 0);
1945 env->ReleaseLongArrayElements(jargsGradHandle, argGradStore, 0);
1946 env->ReleaseLongArrayElements(jargsHandle, inArgs, 0);
1947 env->ReleaseIntArrayElements(jreqsArray, gradReqType, 0);
1948 env->ReleaseLongArrayElements(jauxArgsHandle, auxStates, 0);
1949 for (int i = 0; i < numCtx; i++) {
1950 jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i);
1951 env->ReleaseStringUTFChars(jkey, mapKeys[i]);
1952 env->DeleteLocalRef(jkey);
1953 }
1954 delete[] mapKeys;
1955
1956 SetLongField(env, jexecOut, reinterpret_cast<jlong>(out));
1957 return ret;
1958 }
1959
Java_org_apache_mxnet_LibInfo_mxRandomSeed(JNIEnv * env,jobject obj,jint seed)1960 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRandomSeed
1961 (JNIEnv *env, jobject obj, jint seed) {
1962 return MXRandomSeed(seed);
1963 }
1964
Java_org_apache_mxnet_LibInfo_mxNotifyShutdown(JNIEnv * env,jobject obj)1965 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNotifyShutdown
1966 (JNIEnv *env, jobject obj) {
1967 return MXNotifyShutdown();
1968 }
1969
Java_org_apache_mxnet_LibInfo_mxRecordIOWriterCreate(JNIEnv * env,jobject obj,jstring juri,jobject handle)1970 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterCreate
1971 (JNIEnv *env, jobject obj, jstring juri, jobject handle) {
1972 RecordIOHandle out;
1973 const char *uri = env->GetStringUTFChars(juri, 0);
1974 int ret = MXRecordIOWriterCreate(uri, &out);
1975 env->ReleaseStringUTFChars(juri, uri);
1976 SetLongField(env, handle, reinterpret_cast<jlong>(out));
1977 return ret;
1978 }
1979
Java_org_apache_mxnet_LibInfo_mxRecordIOReaderCreate(JNIEnv * env,jobject obj,jstring juri,jobject handle)1980 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderCreate
1981 (JNIEnv *env, jobject obj, jstring juri, jobject handle) {
1982 RecordIOHandle out;
1983 const char *uri = env->GetStringUTFChars(juri, 0);
1984 int ret = MXRecordIOReaderCreate(uri, &out);
1985 env->ReleaseStringUTFChars(juri, uri);
1986 SetLongField(env, handle, reinterpret_cast<jlong>(out));
1987 return ret;
1988 }
1989
Java_org_apache_mxnet_LibInfo_mxRecordIOWriterFree(JNIEnv * env,jobject obj,jlong handle)1990 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterFree
1991 (JNIEnv *env, jobject obj, jlong handle) {
1992 RecordIOHandle recordIOHandle = reinterpret_cast<RecordIOHandle>(handle);
1993 int ret = MXRecordIOWriterFree(recordIOHandle);
1994 return ret;
1995 }
1996
Java_org_apache_mxnet_LibInfo_mxRecordIOReaderFree(JNIEnv * env,jobject obj,jlong handle)1997 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderFree
1998 (JNIEnv *env, jobject obj, jlong handle) {
1999 RecordIOHandle recordIOHandle = reinterpret_cast<RecordIOHandle>(handle);
2000 int ret = MXRecordIOReaderFree(&recordIOHandle);
2001 return ret;
2002 }
2003
Java_org_apache_mxnet_LibInfo_mxRecordIOWriterWriteRecord(JNIEnv * env,jobject obj,jlong handle,jstring jbuf,jint size)2004 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterWriteRecord
2005 (JNIEnv *env, jobject obj, jlong handle, jstring jbuf, jint size) {
2006 const char *buf = env->GetStringUTFChars(jbuf, 0);
2007 RecordIOHandle *recordIOHandle = reinterpret_cast<RecordIOHandle *>(handle);
2008 int ret = MXRecordIOWriterWriteRecord(recordIOHandle, buf, size);
2009 env->ReleaseStringUTFChars(jbuf, buf);
2010 return ret;
2011 }
2012
Java_org_apache_mxnet_LibInfo_mxRecordIOReaderReadRecord(JNIEnv * env,jobject obj,jlong handle,jobject buf)2013 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderReadRecord
2014 (JNIEnv *env, jobject obj, jlong handle, jobject buf) {
2015 RecordIOHandle *recordIOHandle = reinterpret_cast<RecordIOHandle *>(handle);
2016 size_t size;
2017 char const *out;
2018 int ret = MXRecordIOReaderReadRecord(recordIOHandle, &out, &size);
2019 SetStringField(env, buf, out);
2020 return ret;
2021 }
2022
Java_org_apache_mxnet_LibInfo_mxRecordIOWriterTell(JNIEnv * env,jobject obj,jlong handle,jobject jpos)2023 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOWriterTell
2024 (JNIEnv *env, jobject obj, jlong handle, jobject jpos) {
2025 RecordIOHandle *recordIOHandle = reinterpret_cast<RecordIOHandle *>(handle);
2026 size_t pos;
2027 int ret = MXRecordIOWriterTell(recordIOHandle, &pos);
2028 SetIntField(env, jpos, pos);
2029 return ret;
2030 }
2031
Java_org_apache_mxnet_LibInfo_mxRecordIOReaderSeek(JNIEnv * env,jobject obj,jlong handle,jint pos)2032 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRecordIOReaderSeek
2033 (JNIEnv *env, jobject obj, jlong handle, jint pos) {
2034 RecordIOHandle *recordIOHandle = reinterpret_cast<RecordIOHandle *>(handle);
2035 int ret = MXRecordIOReaderSeek(recordIOHandle, pos);
2036 return ret;
2037 }
2038
Java_org_apache_mxnet_LibInfo_mxRtcCreate(JNIEnv * env,jobject obj,jstring jname,jobjectArray jinputNames,jobjectArray joutputNames,jlongArray jinputs,jlongArray joutputs,jstring jkernel,jobject jhandle)2039 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcCreate
2040 (JNIEnv *env, jobject obj, jstring jname, jobjectArray jinputNames,
2041 jobjectArray joutputNames, jlongArray jinputs, jlongArray joutputs,
2042 jstring jkernel, jobject jhandle) {
2043 RtcHandle out;
2044 char *name = const_cast<char *>(env->GetStringUTFChars(jname, 0));
2045 int num_input = env->GetArrayLength(jinputNames);
2046 char **inputNames = new char *[num_input];
2047 for (int i = 0; i < num_input; i++) {
2048 jstring jinname = reinterpret_cast<jstring>(env->GetObjectArrayElement(jinputNames, i));
2049 char *inname = const_cast<char *>(env->GetStringUTFChars(jinname, 0));
2050 inputNames[i] = inname;
2051 env->DeleteLocalRef(jinname);
2052 }
2053 int num_output = env->GetArrayLength(joutputNames);
2054 char **outputNames = new char *[num_output];
2055 for (int i = 0; i < num_output; i++) {
2056 jstring joutname = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputNames, i));
2057 char *outname = const_cast<char *>(env->GetStringUTFChars(joutname, 0));
2058 outputNames[i] = outname;
2059 env->DeleteLocalRef(joutname);
2060 }
2061 jlong *inputs = env->GetLongArrayElements(jinputs, NULL);
2062 jlong *outputs = env->GetLongArrayElements(joutputs, NULL);
2063 char *kernel = const_cast<char *>(env->GetStringUTFChars(jkernel, 0));
2064
2065 int ret = MXRtcCreate(name,
2066 static_cast<mx_uint>(num_input),
2067 static_cast<mx_uint>(num_output),
2068 inputNames,
2069 outputNames,
2070 reinterpret_cast<NDArrayHandle *>(inputs),
2071 reinterpret_cast<NDArrayHandle *>(outputs),
2072 kernel,
2073 &out);
2074
2075 // release allocated memory
2076 env->ReleaseStringUTFChars(jname, name);
2077 env->ReleaseStringUTFChars(jkernel, kernel);
2078 env->ReleaseLongArrayElements(jinputs, inputs, 0);
2079 env->ReleaseLongArrayElements(joutputs, outputs, 0);
2080 for (int i = 0; i < num_input; i++) {
2081 jstring jinname = reinterpret_cast<jstring>(env->GetObjectArrayElement(jinputNames, i));
2082 env->ReleaseStringUTFChars(jinname, inputNames[i]);
2083 env->DeleteLocalRef(jinname);
2084 }
2085 delete[] inputNames;
2086 for (int i = 0; i < num_output; i++) {
2087 jstring joutname = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputNames, i));
2088 env->ReleaseStringUTFChars(joutname, outputNames[i]);
2089 env->DeleteLocalRef(joutname);
2090 }
2091 delete[] outputNames;
2092
2093 SetLongField(env, jhandle, reinterpret_cast<jlong>(out));
2094 return ret;
2095 }
2096
Java_org_apache_mxnet_LibInfo_mxRtcPush(JNIEnv * env,jobject obj,jlong jhandle,jlongArray jinputs,jlongArray joutputs,jint gridDimX,jint gridDimY,jint gridDimZ,jint blockDimX,jint blockDimY,jint blockDimZ)2097 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcPush
2098 (JNIEnv *env, jobject obj, jlong jhandle, jlongArray jinputs,
2099 jlongArray joutputs, jint gridDimX, jint gridDimY, jint gridDimZ,
2100 jint blockDimX, jint blockDimY, jint blockDimZ) {
2101
2102 RtcHandle handle = reinterpret_cast<RtcHandle>(jhandle);
2103 jlong *inputs = env->GetLongArrayElements(jinputs, NULL);
2104 jlong *outputs = env->GetLongArrayElements(joutputs, NULL);
2105 int num_input = env->GetArrayLength(jinputs);
2106 int num_output = env->GetArrayLength(joutputs);
2107
2108 int ret = MXRtcPush(handle,
2109 static_cast<mx_uint>(num_input),
2110 static_cast<mx_uint>(num_output),
2111 reinterpret_cast<NDArrayHandle *>(inputs),
2112 reinterpret_cast<NDArrayHandle *>(outputs),
2113 static_cast<mx_uint>(gridDimX),
2114 static_cast<mx_uint>(gridDimY),
2115 static_cast<mx_uint>(gridDimZ),
2116 static_cast<mx_uint>(blockDimX),
2117 static_cast<mx_uint>(blockDimY),
2118 static_cast<mx_uint>(blockDimZ));
2119
2120 // release allocated memory
2121 env->ReleaseLongArrayElements(jinputs, inputs, 0);
2122 env->ReleaseLongArrayElements(joutputs, outputs, 0);
2123
2124 return ret;
2125 }
2126
Java_org_apache_mxnet_LibInfo_mxRtcFree(JNIEnv * env,jobject obj,jlong jhandle)2127 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcFree
2128 (JNIEnv *env, jobject obj, jlong jhandle) {
2129 RtcHandle handle = reinterpret_cast<RtcHandle>(jhandle);
2130 int ret = MXRtcFree(handle);
2131 return ret;
2132 }
2133
2134 // store the user defined CustomOpProp object reference with its name
2135 std::unordered_map<std::string, jobject> globalOpPropMap;
2136 // store the user defined CustomOp object reference with its name
2137 std::unordered_map<std::string, jobject> globalOpMap;
2138 // used for thread safty when insert elements into
2139 // or erase elements from the std::unordered_map
2140 std::mutex mutex_opprop;
2141 std::mutex mutex_op;
2142
2143 // Registers a custom operator when called
Java_org_apache_mxnet_LibInfo_mxCustomOpRegister(JNIEnv * env,jobject obj,jstring jregName,jobject jopProp)2144 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister
2145 (JNIEnv *env, jobject obj, jstring jregName, jobject jopProp) {
2146 const char *regName = env->GetStringUTFChars(jregName, 0);
2147 std::string key(regName);
2148
2149 std::unique_lock<std::mutex> lock(mutex_opprop);
2150 globalOpPropMap.insert({ key, env->NewGlobalRef(jopProp) });
2151 lock.unlock();
2152
2153 // lambda function to initialize the operator and create all callbacks
2154 auto creatorLambda = [](const char *opType, const int numKwargs,
2155 const char **keys, const char **values, MXCallbackList *ret) {
2156 int success = true;
2157
2158 std::string opPropKey(opType);
2159 if (globalOpPropMap.find(opPropKey) == globalOpPropMap.end()) {
2160 LOG(WARNING) << "CustomOpProp: " << opPropKey << " not found";
2161 success = false;
2162 } else {
2163 JNIEnv *env;
2164 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2165 jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(opPropKey));
2166 jmethodID midInit = env->GetMethodID(opPropClass,
2167 "init", "([Ljava/lang/String;[Ljava/lang/String;)V");
2168 if (NULL == midInit) {
2169 LOG(WARNING) << "could not find CustomOpProp method init.";
2170 success = false;
2171 } else {
2172 // call init and set CustomOpProp.kwargs
2173 jclass strCls = env->FindClass("Ljava/lang/String;");
2174 jobjectArray keysArr = env->NewObjectArray(numKwargs, strCls, NULL);
2175 jobjectArray valuesArr = env->NewObjectArray(numKwargs, strCls, NULL);
2176 for (int i = 0; i < numKwargs; ++i) {
2177 jstring keyStr = env->NewStringUTF(keys[i]);
2178 jstring valueStr = env->NewStringUTF(values[i]);
2179 env->SetObjectArrayElement(keysArr, i, keyStr);
2180 env->SetObjectArrayElement(valuesArr, i, valueStr);
2181 env->DeleteLocalRef(keyStr);
2182 env->DeleteLocalRef(valueStr);
2183 }
2184 env->CallVoidMethod(globalOpPropMap.at(opPropKey), midInit, keysArr, valuesArr);
2185 env->DeleteLocalRef(keysArr);
2186 env->DeleteLocalRef(valuesArr);
2187 }
2188 _jvm->DetachCurrentThread();
2189 }
2190
2191 // list_arguments callback
2192 auto opPropListArgument = [](char ***args, void *state) {
2193 int success = true;
2194 std::string key(reinterpret_cast<char *>(state));
2195 if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2196 LOG(WARNING) << "CustomOpProp: " << key << " not found";
2197 success = false;
2198 } else {
2199 JNIEnv *env;
2200 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2201 jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2202 jmethodID midListArguments = env->GetMethodID(
2203 opPropClass, "listArguments", "()[Ljava/lang/String;");
2204 if (NULL == midListArguments) {
2205 LOG(WARNING) << "could not find opProp method listArguments.";
2206 success = false;
2207 } else {
2208 jobjectArray jargs =(jobjectArray)(env->CallObjectMethod(
2209 globalOpPropMap.at(key), midListArguments));
2210 int len = env->GetArrayLength(jargs);
2211 *args = new char *[len+1];
2212 for (int i = 0; i < len; ++i) {
2213 jstring jarg = reinterpret_cast<jstring>(env->GetObjectArrayElement(jargs, i));
2214 const char *arg = env->GetStringUTFChars(jarg, 0);
2215 (*args)[i] = const_cast<char *>(arg);
2216 env->DeleteLocalRef(jarg);
2217 }
2218 (*args)[len] = NULL;
2219 }
2220 _jvm->DetachCurrentThread();
2221 }
2222 return success;
2223 };
2224
2225 // list_outputs callback
2226 auto opPropListOutputs = [](char ***outputs, void *state) {
2227 int success = true;
2228 std::string key(reinterpret_cast<char *>(state));
2229 if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2230 LOG(WARNING) << "CustomOpProp: " << key << " not found";
2231 success = false;
2232 } else {
2233 JNIEnv *env;
2234 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2235 jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2236 jmethodID midListOutputs = env->GetMethodID(
2237 opPropClass, "listOutputs", "()[Ljava/lang/String;");
2238 if (NULL == midListOutputs) {
2239 LOG(WARNING) << "could not find opProp method listOutputs.";
2240 success = false;
2241 } else {
2242 jobjectArray joutputs = (jobjectArray)(env->CallObjectMethod(
2243 globalOpPropMap.at(key), midListOutputs));
2244 int len = env->GetArrayLength(joutputs);
2245 *outputs = new char *[len + 1];
2246 for (int i = 0; i < len; ++i) {
2247 jstring joutput = reinterpret_cast<jstring>(env->GetObjectArrayElement(joutputs, i));
2248 const char *output = env->GetStringUTFChars(joutput, 0);
2249 (*outputs)[i] = const_cast<char *>(output);
2250 env->DeleteLocalRef(joutput);
2251 }
2252 (*outputs)[len] = NULL;
2253 }
2254 _jvm->DetachCurrentThread();
2255 }
2256 return success;
2257 };
2258
2259 // list_auxiliary_states callback
2260 auto opPropListAuxStates = [](char ***auxs, void *state) {
2261 int success = true;
2262 std::string key(reinterpret_cast<char *>(state));
2263 if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2264 LOG(WARNING) << "CustomOpProp: " << key << " not found";
2265 success = false;
2266 } else {
2267 JNIEnv *env;
2268 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2269 jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2270 jmethodID midListAuxStates = env->GetMethodID(
2271 opPropClass, "listAuxiliaryStates", "()[Ljava/lang/String;");
2272 if (NULL == midListAuxStates) {
2273 LOG(WARNING) << "could not find opProp method listAuxiliaryStates.";
2274 success = false;
2275 } else {
2276 auto obj = env->CallObjectMethod(globalOpPropMap.at(key), midListAuxStates);
2277 if (obj != NULL) {
2278 jobjectArray jauxs = (jobjectArray)obj;
2279 int len = env->GetArrayLength(jauxs);
2280 *auxs = new char *[len+1];
2281 for (int i = 0; i < len; ++i) {
2282 jstring jaux = reinterpret_cast<jstring>(env->GetObjectArrayElement(jauxs, i));
2283 const char *aux = env->GetStringUTFChars(jaux, 0);
2284 (*auxs)[i] = const_cast<char *>(aux);
2285 env->DeleteLocalRef(jaux);
2286 }
2287 (*auxs)[len] = NULL;
2288 } else {
2289 (*auxs) = new char *[1];
2290 (*auxs)[0] = NULL;
2291 }
2292 }
2293 _jvm->DetachCurrentThread();
2294 }
2295 return success;
2296 };
2297
2298 // declare_backward_dependency callback
2299 auto opPropDeclareBkDep = [](const int *outGrad, const int *inData,
2300 const int *outData, int *numDeps, int **rdeps, void *state) {
2301 int success = true;
2302 std::string key(reinterpret_cast<char *>(state));
2303 if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2304 LOG(WARNING) << "CustomOpProp: " << key << " not found";
2305 success = false;
2306 } else {
2307 JNIEnv *env;
2308 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2309 jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2310 jmethodID midDeclareBkDep = env->GetMethodID(
2311 opPropClass, "declareBackwardDependency", "([I[I[I)[I");
2312 if (NULL == midDeclareBkDep) {
2313 LOG(WARNING) << "could not find opProp method declareBackwardDependency.";
2314 success = false;
2315 } else {
2316 jmethodID midListOutputs = env->GetMethodID(
2317 opPropClass, "listOutputs", "()[Ljava/lang/String;");
2318 jobjectArray joutputs = (jobjectArray)(env->CallObjectMethod(
2319 globalOpPropMap.at(key), midListOutputs));
2320 int outLen = env->GetArrayLength(joutputs);
2321 jmethodID midListArguments = env->GetMethodID(
2322 opPropClass, "listArguments", "()[Ljava/lang/String;");
2323 jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
2324 globalOpPropMap.at(key), midListArguments));
2325 int intLen = env->GetArrayLength(jargs);
2326
2327 jintArray outGradArr = env->NewIntArray(outLen);
2328 env->SetIntArrayRegion(outGradArr, (jsize)0, (jsize)outLen, outGrad);
2329 jintArray inDataArr = env->NewIntArray(intLen);
2330 env->SetIntArrayRegion(inDataArr, (jsize)0, (jsize)intLen, inData);
2331 jintArray outDataArr = env->NewIntArray(outLen);
2332 env->SetIntArrayRegion(outDataArr, (jsize)0, (jsize)outLen, outData);
2333
2334 auto obj = env->CallObjectMethod(globalOpPropMap.at(key), midDeclareBkDep,
2335 outGradArr,
2336 inDataArr,
2337 outDataArr);
2338 jintArray jrdeps = (jintArray)obj;
2339 jint *rdepsArr = env->GetIntArrayElements(jrdeps, NULL);
2340
2341 *numDeps = env->GetArrayLength(jrdeps);
2342 *rdeps = new int[(* numDeps)];
2343 for (int i = 0 ; i < (*numDeps); ++i) {
2344 (*rdeps)[i] = rdepsArr[i];
2345 }
2346 env->DeleteLocalRef(outGradArr);
2347 env->DeleteLocalRef(inDataArr);
2348 env->DeleteLocalRef(outDataArr);
2349 env->ReleaseIntArrayElements(jrdeps, rdepsArr, 0);
2350 }
2351 _jvm->DetachCurrentThread();
2352 }
2353 return success;
2354 };
2355
2356 // infer_shape callback
2357 auto opPropInferShape = [](int numInput, int *ndims,
2358 unsigned **shapes, void *state) {
2359 int success = true;
2360 std::string key(reinterpret_cast<char *>(state));
2361 if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2362 LOG(WARNING) << "CustomOpProp: " << key << " not found";
2363 success = false;
2364 } else {
2365 JNIEnv *env;
2366 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2367 jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2368 jmethodID midInferShape = env->GetMethodID(opPropClass, "inferShapeEntry", "(I[[I)[[I");
2369 if (NULL == midInferShape) {
2370 LOG(WARNING) << "could not find opProp method inferShapeEntry.";
2371 success = false;
2372 } else {
2373 jmethodID midListArguments = env->GetMethodID(
2374 opPropClass, "listArguments", "()[Ljava/lang/String;");
2375 jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
2376 globalOpPropMap.at(key), midListArguments));
2377 int intLen = env->GetArrayLength(jargs);
2378 jintArray *ts = new jintArray[intLen];
2379 auto tmp = env->NewIntArray(1);
2380 jclass arrayClass = env->GetObjectClass(tmp);
2381 env->DeleteLocalRef(tmp);
2382 jobjectArray tensorShapes = env->NewObjectArray(intLen, arrayClass, NULL);
2383 for (int i = 0; i < intLen; ++i) {
2384 ts[i] = env->NewIntArray(ndims[i]);
2385 env->SetIntArrayRegion(
2386 ts[i], (jsize)0, (jsize)ndims[i], reinterpret_cast<int *>(shapes[i]));
2387 env->SetObjectArrayElement(tensorShapes, i, (jobject)(ts[i]));
2388 }
2389 jobjectArray ret = (jobjectArray)(env->CallObjectMethod(
2390 globalOpPropMap.at(key), midInferShape,
2391 numInput,
2392 tensorShapes));
2393 for (int i = 0; i < numInput; ++i) {
2394 jintArray jarr = reinterpret_cast<jintArray>(env->GetObjectArrayElement(ret, i));
2395 int len = env->GetArrayLength(jarr);
2396 jint *arr = env->GetIntArrayElements(jarr, NULL);
2397 ndims[i] = len;
2398 shapes[i] = new unsigned[len];
2399 for (int j = 0; j < len; ++j) shapes[i][j] = (unsigned)(arr[j]);
2400 env->DeleteLocalRef(jarr);
2401 }
2402 for (int i = 0; i < intLen; ++i) {
2403 env->DeleteLocalRef(ts[i]);
2404 }
2405 delete[] ts;
2406 }
2407 _jvm->DetachCurrentThread();
2408 }
2409 return success;
2410 };
2411
2412 // infer_type callback
2413 auto opPropInferType = [](int numInput, int* types, void* state) {
2414 int success = true;
2415 std::string key(reinterpret_cast<char *>(state));
2416 if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2417 LOG(WARNING) << "CustomOpProp: " << key << " not found";
2418 success = false;
2419 } else {
2420 JNIEnv *env;
2421 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2422 jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2423 jmethodID midInferType = env->GetMethodID(opPropClass, "inferTypeEntry", "(I[I)[I");
2424 if (NULL == midInferType) {
2425 LOG(WARNING) << "could not find opProp method inferTypeEntry.";
2426 success = false;
2427 } else {
2428 jmethodID midListArguments = env->GetMethodID(
2429 opPropClass, "listArguments", "()[Ljava/lang/String;");
2430 jobjectArray jargs = (jobjectArray)(env->CallObjectMethod(
2431 globalOpPropMap.at(key), midListArguments));
2432
2433 int intLen = env->GetArrayLength(jargs);
2434 jintArray ts = env->NewIntArray(intLen);
2435 int *tmp = new int[intLen];
2436 for (int i = 0; i < intLen; ++i) tmp[i] = types[i];
2437 env->SetIntArrayRegion(ts, (jsize)0, (jsize)intLen, tmp);
2438
2439 jintArray ret = (jintArray)(env->CallObjectMethod(
2440 globalOpPropMap.at(key), midInferType,
2441 numInput,
2442 ts));
2443 jint *arr = env->GetIntArrayElements(ret, NULL);
2444 for (int i = 0; i < numInput; ++i) {
2445 types[i] = static_cast<int>(arr[i]);
2446 }
2447
2448 delete[] tmp;
2449 env->ReleaseIntArrayElements(ret, arr, 0);
2450 env->DeleteLocalRef(ret);
2451 env->DeleteLocalRef(ts);
2452 }
2453 _jvm->DetachCurrentThread();
2454 }
2455 return success;
2456 };
2457
2458 // create_operator callback
2459 auto opPropCreateOp = [](const char *ctx, int numInputs,
2460 unsigned **shapes, int *ndims, int *dtypes, MXCallbackList *ret, void *state) {
2461 int success = true;
2462 std::string key(reinterpret_cast<char *>(state));
2463 if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2464 LOG(WARNING) << "CustomOpProp: " << key << " not found";
2465 success = false;
2466 } else {
2467 JNIEnv *env;
2468 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2469 jclass opPropClass = env->GetObjectClass(globalOpPropMap.at(key));
2470 jmethodID midCreateOp = env->GetMethodID(
2471 opPropClass, "createOperator", "(Ljava/lang/String;[[I[I)Lorg/apache/mxnet/CustomOp;");
2472 if (NULL == midCreateOp) {
2473 LOG(WARNING) << "could not find opProp method createOperator.";
2474 success = false;
2475 } else {
2476 jstring jctx = env->NewStringUTF(ctx);
2477 jintArray *ts = new jintArray[numInputs];
2478 auto tmp = env->NewIntArray(1);
2479 jclass arrayClass = env->GetObjectClass(tmp);
2480 env->DeleteLocalRef(tmp);
2481 jobjectArray inputShapes = env->NewObjectArray(numInputs, arrayClass, NULL);
2482 for (int i = 0; i < numInputs; ++i) {
2483 ts[i] = env->NewIntArray(ndims[i]);
2484 env->SetIntArrayRegion(
2485 ts[i], (jsize)0, (jsize)ndims[i], reinterpret_cast<int *>(shapes[i]));
2486 env->SetObjectArrayElement(inputShapes, i, (jobject)(ts[i]));
2487 }
2488 jintArray jdtypes = env->NewIntArray(numInputs);
2489 env->SetIntArrayRegion(jdtypes, (jsize)0, (jsize)numInputs, dtypes);
2490 // get operator
2491 jobject jOp = env->CallObjectMethod(globalOpPropMap.at(key), midCreateOp,
2492 jctx,
2493 inputShapes,
2494 jdtypes);
2495 env->DeleteLocalRef(jctx);
2496 for (int i = 0; i < numInputs; ++i) {
2497 env->DeleteLocalRef(ts[i]);
2498 }
2499 delete[] ts;
2500
2501 std::unique_lock<std::mutex> lock(mutex_op);
2502 globalOpMap.insert({ key, env->NewGlobalRef(jOp) });
2503 lock.unlock();
2504
2505 _jvm->DetachCurrentThread();
2506
2507 // forward callback
2508 auto forwardEntry = [](int size, void **ptrs, int *tags,
2509 const int *reqs, const int isTrain, void *state) {
2510 std::string key(reinterpret_cast<char *>(state));
2511 int success = true;
2512 if (globalOpMap.find(key) == globalOpMap.end()) {
2513 LOG(WARNING) << "op: " << key << " not found";
2514 success = false;
2515 } else {
2516 JNIEnv *env;
2517 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2518 jclass opClass = env->GetObjectClass(globalOpMap.at(key));
2519 jmethodID midForward = env->GetMethodID(opClass, "forwardEntry", "(I[J[I[IZ)Z");
2520 if (NULL == midForward) {
2521 LOG(WARNING) << "could not find op method forwardEntry.";
2522 success = false;
2523 } else {
2524 jintArray tagsArr = env->NewIntArray(size);
2525 env->SetIntArrayRegion(tagsArr, (jsize)0, (jsize)size, tags);
2526 int reqSize = 0;
2527 for (int i = 0; i < size; ++i) {
2528 if (tags[i] == 1) reqSize++;
2529 }
2530 jintArray reqsArr = env->NewIntArray(reqSize);
2531 env->SetIntArrayRegion(reqsArr, (jsize)0, (jsize)reqSize, reqs);
2532 jlongArray ptrsArr = env->NewLongArray(size);
2533 env->SetLongArrayRegion(
2534 ptrsArr, (jsize)0, (jsize)size, reinterpret_cast<jlong*>(ptrs));
2535 #if MXNET_USE_CUDA
2536 mxnet::NDArray* tmp = reinterpret_cast<mxnet::NDArray*>(ptrs[0]);
2537 if (tmp->ctx().dev_type == mxnet::Context::kGPU
2538 || tmp->ctx().dev_type == mxnet::Context::kCPUPinned) {
2539 CUDA_CALL(cudaSetDevice(tmp->ctx().dev_id));
2540 }
2541 #endif
2542 bool is_train = true;
2543 if (isTrain == 0) is_train = false;
2544 success = env->CallBooleanMethod(globalOpMap.at(key), midForward,
2545 size,
2546 ptrsArr,
2547 tagsArr,
2548 reqsArr,
2549 is_train);
2550 env->DeleteLocalRef(tagsArr);
2551 env->DeleteLocalRef(reqsArr);
2552 env->DeleteLocalRef(ptrsArr);
2553 }
2554 _jvm->DetachCurrentThread();
2555 }
2556 return success;
2557 };
2558
2559 // backward callback
2560 auto backwardEntry = [](int size, void **ptrs, int *tags,
2561 const int *reqs, const int isTrain, void *state) {
2562 std::string key(reinterpret_cast<char *>(state));
2563 int success = true;
2564 if (globalOpMap.find(key) == globalOpMap.end()) {
2565 LOG(WARNING) << "op: " << key << " not found";
2566 success = false;
2567 } else {
2568 JNIEnv *env;
2569 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2570 jclass opClass = env->GetObjectClass(globalOpMap.at(key));
2571 jmethodID midBackward = env->GetMethodID(opClass, "backwardEntry", "(I[J[I[IZ)Z");
2572 if (NULL == midBackward) {
2573 LOG(WARNING) << "could not find op method backwardEntry.";
2574 success = false;
2575 } else {
2576 jintArray tagsArr = env->NewIntArray(size);
2577 env->SetIntArrayRegion(tagsArr, (jsize)0, (jsize)size, tags);
2578
2579 int reqSize = 0;
2580 for (int i = 0; i < size; ++i) {
2581 if (tags[i] == 2) reqSize++;
2582 }
2583 jintArray reqsArr = env->NewIntArray(reqSize);
2584 env->SetIntArrayRegion(reqsArr, (jsize)0, (jsize)reqSize, reqs);
2585 jlongArray ptrsArr = env->NewLongArray(size);
2586 env->SetLongArrayRegion(
2587 ptrsArr, (jsize)0, (jsize)size, reinterpret_cast<jlong*>(ptrs));
2588 bool is_train = true;
2589 if (isTrain == 0) is_train = false;
2590 success = env->CallBooleanMethod(globalOpMap.at(key), midBackward,
2591 size,
2592 ptrsArr,
2593 tagsArr,
2594 reqsArr,
2595 is_train);
2596 env->DeleteLocalRef(tagsArr);
2597 env->DeleteLocalRef(reqsArr);
2598 env->DeleteLocalRef(ptrsArr);
2599 }
2600 _jvm->DetachCurrentThread();
2601 }
2602 return success;
2603 };
2604
2605 // del callback
2606 auto delEntry = [](void *state) {
2607 std::string key(reinterpret_cast<char *>(state));
2608 int success = true;
2609 std::unique_lock<std::mutex> lock(mutex_op);
2610 if (globalOpMap.find(key) == globalOpMap.end()) {
2611 LOG(WARNING) << "op: " << key << " not found";
2612 success = false;
2613 } else {
2614 JNIEnv *env;
2615 _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2616 env->DeleteGlobalRef(globalOpMap.at(key));
2617 _jvm->DetachCurrentThread();
2618 for (auto it = globalOpMap.begin(); it != globalOpMap.end(); ) {
2619 if (it->first == key) {
2620 it = globalOpMap.erase(it);
2621 } else {
2622 ++it;
2623 }
2624 }
2625 }
2626 lock.unlock();
2627 return success;
2628 };
2629
2630 // TODO(eric): Memory leak here. Refactor later and delete in delEntry
2631 ret->num_callbacks = 3;
2632 ret->callbacks = new MXGenericCallback[ret->num_callbacks];
2633 ret->callbacks[kCustomOpDelete] =
2634 reinterpret_cast<int(*)(void)>(static_cast<int(*)(void*)>(delEntry));
2635 ret->callbacks[kCustomOpForward] =
2636 reinterpret_cast<int(*)(void)>(
2637 static_cast<int(*)(int, void**, int*, const int*, const int, void*)>(
2638 forwardEntry));
2639 ret->callbacks[kCustomOpBackward] =
2640 reinterpret_cast<int(*)(void)>(
2641 static_cast<int(*)(int, void**, int*, const int*, const int, void*)>(
2642 backwardEntry));
2643 ret->contexts = new void*[ret->num_callbacks];
2644 ret->contexts[kCustomOpDelete] = state;
2645 ret->contexts[kCustomOpForward] = state;
2646 ret->contexts[kCustomOpBackward] = state;
2647 }
2648 }
2649 return success;
2650 };
2651
2652 // del callback
2653 auto opPropDel = [](void *state) {
2654 /*
2655 * This method seems to be called by the engine to clean up after multiple calls were made
2656 * to the creator lambda. The current creator function isn't allocating a new object but is
2657 * instead reinitializing the object which was created when register was called. This means
2658 * that there doesn't seem to be anything to clean up here (previous efforts were actually
2659 * deregistering the operator).
2660 */
2661 return 1;
2662 };
2663
2664 // TODO(eric): Memory leak. Missing infertype.
2665 ret->num_callbacks = 8;
2666 ret->callbacks = new MXGenericCallback[ret->num_callbacks];
2667 ret->callbacks[kCustomOpPropDelete] =
2668 reinterpret_cast<int(*)(void)>(
2669 static_cast<int(*)(void*)>(opPropDel));
2670 ret->callbacks[kCustomOpPropListArguments] =
2671 reinterpret_cast<int(*)(void)>(
2672 static_cast<int(*)(char***, void*)>(opPropListArgument));
2673 ret->callbacks[kCustomOpPropListOutputs] =
2674 reinterpret_cast<int(*)(void)>(
2675 static_cast<int(*)(char***, void*)>(opPropListOutputs));
2676 ret->callbacks[kCustomOpPropListAuxiliaryStates] =
2677 reinterpret_cast<int(*)(void)>(
2678 static_cast<int(*)(char***, void*)>(opPropListAuxStates));
2679 ret->callbacks[kCustomOpPropInferShape] =
2680 reinterpret_cast<int(*)(void)>(
2681 static_cast<int (*)(int, int*, unsigned**, void*)>(opPropInferShape));
2682 ret->callbacks[kCustomOpPropDeclareBackwardDependency] =
2683 reinterpret_cast<int(*)(void)>(
2684 static_cast<int(*)(const int*, const int*, const int*, int* num_deps, int**, void*)>(
2685 opPropDeclareBkDep));
2686 ret->callbacks[kCustomOpPropCreateOperator] =
2687 reinterpret_cast<int(*)(void)>(
2688 static_cast<int(*)(const char*, int, unsigned**, int*, int*, MXCallbackList*, void*)>(
2689 opPropCreateOp));
2690 ret->callbacks[kCustomOpPropInferType] =
2691 reinterpret_cast<int(*)(void)>(
2692 static_cast<int(*)(int, int*, void*)>(opPropInferType));
2693
2694 ret->contexts = new void*[ret->num_callbacks];
2695 ret->contexts[kCustomOpPropDelete] =
2696 reinterpret_cast<void *>(const_cast<char *>(opType));
2697 ret->contexts[kCustomOpPropListArguments] =
2698 reinterpret_cast<void *>(const_cast<char *>(opType));
2699 ret->contexts[kCustomOpPropListOutputs] =
2700 reinterpret_cast<void *>(const_cast<char *>(opType));
2701 ret->contexts[kCustomOpPropListAuxiliaryStates] =
2702 reinterpret_cast<void *>(const_cast<char *>(opType));
2703 ret->contexts[kCustomOpPropInferShape] =
2704 reinterpret_cast<void *>(const_cast<char *>(opType));
2705 ret->contexts[kCustomOpPropDeclareBackwardDependency] =
2706 reinterpret_cast<void *>(const_cast<char *>(opType));
2707 ret->contexts[kCustomOpPropCreateOperator] =
2708 reinterpret_cast<void *>(const_cast<char *>(opType));
2709 ret->contexts[kCustomOpPropInferType] =
2710 reinterpret_cast<void *>(const_cast<char *>(opType));
2711 return success;
2712 };
2713
2714 CustomOpPropCreator creator =
2715 static_cast<int(*)(const char*, const int, const char**, const char**, MXCallbackList*)>(
2716 creatorLambda);
2717 return MXCustomOpRegister(regName, creator);
2718 }
2719
2720 struct JNIString {
2721 JNIEnv *env_;
2722 jstring java_string_;
2723 const char *str_;
JNIStringJNIString2724 inline JNIString(JNIEnv *env, const jstring& java_string)
2725 : env_(env)
2726 , java_string_(java_string) {
2727 str_ = env_->GetStringUTFChars(java_string_, 0);
2728 }
~JNIStringJNIString2729 inline ~JNIString() {
2730 if (str_) {
2731 env_->ReleaseStringUTFChars(java_string_, str_);
2732 }
2733 }
operator ()JNIString2734 inline const char *operator ()() const {
2735 return str_;
2736 }
2737 };
2738
2739 struct JNIStringArray {
2740 std::vector<std::unique_ptr<JNIString>> jni_strings_;
2741 std::vector<const char *> strings_;
JNIStringArrayJNIStringArray2742 JNIStringArray(JNIEnv *env, const jobjectArray& stringArray) {
2743 const int count = env->GetArrayLength(stringArray);
2744 jni_strings_.reserve(count);
2745 strings_.reserve(count);
2746 for (int i = 0; i < count; ++i) {
2747 jstring string = static_cast<jstring>(env->GetObjectArrayElement(stringArray, i));
2748 jni_strings_.emplace_back(std::unique_ptr<JNIString>(new JNIString(env, string)));
2749 strings_.emplace_back((*jni_strings_.rbegin())->str_);
2750 }
2751 }
operator ()JNIStringArray2752 const char * const* operator ()() const { return &strings_[0]; }
2753 };
2754
Java_org_apache_mxnet_LibInfo_mxSetProfilerConfig(JNIEnv * env,jobject obj,jobjectArray keys,jobjectArray vals)2755 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetProfilerConfig
2756 (JNIEnv *env, jobject obj, jobjectArray keys, jobjectArray vals) {
2757 const int stringCount = env->GetArrayLength(keys);
2758 CHECK_EQ(stringCount, env->GetArrayLength(vals)) << "Key and value arrays must be the same size";
2759
2760 JNIStringArray the_keys(env, keys), the_vals(env, vals);
2761
2762 const int ret = MXSetProfilerConfig(stringCount, the_keys(), the_vals());
2763 return ret;
2764 }
2765
Java_org_apache_mxnet_LibInfo_mxSetProfilerState(JNIEnv * env,jobject obj,jint jstate)2766 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetProfilerState
2767 (JNIEnv *env, jobject obj, jint jstate) {
2768 return MXSetProfilerState(jstate);
2769 }
2770
Java_org_apache_mxnet_LibInfo_mxDumpProfile(JNIEnv * env,jobject obj,jint finished)2771 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
2772 (JNIEnv *env, jobject obj, jint finished) {
2773 return MXDumpProfile(finished);
2774 }
2775
2776 // Numpy
Java_org_apache_mxnet_LibInfo_mxIsNumpyShape(JNIEnv * env,jobject obj,jobject compatibleRef)2777 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyShape
2778 (JNIEnv *env, jobject obj, jobject compatibleRef) {
2779 int isNumpyShape;
2780 int ret = MXIsNumpyShape(&isNumpyShape);
2781 SetIntField(env, compatibleRef, isNumpyShape);
2782 return ret;
2783 }
2784
Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape(JNIEnv * env,jobject obj,jint isNpComp,jobject prevRef)2785 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyShape
2786 (JNIEnv *env, jobject obj, jint isNpComp, jobject prevRef) {
2787 int prev;
2788 int ret = MXSetIsNumpyShape(isNpComp, &prev);
2789 SetIntField(env, prevRef, prev);
2790 return ret;
2791 }
2792