1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*!
20  * \file jni_helper_func.h
21  * \brief Helper functions for operating JVM objects
22  */
23 #include <jni.h>
24 
25 #ifndef TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
26 #define TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
27 
28 // Helper functions for RefXXX getter & setter
getLongField(JNIEnv * env,jobject obj)29 jlong getLongField(JNIEnv* env, jobject obj) {
30   jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong");
31   jfieldID refFid = env->GetFieldID(refClass, "value", "J");
32   jlong ret = env->GetLongField(obj, refFid);
33   env->DeleteLocalRef(refClass);
34   return ret;
35 }
36 
getIntField(JNIEnv * env,jobject obj)37 jint getIntField(JNIEnv* env, jobject obj) {
38   jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt");
39   jfieldID refFid = env->GetFieldID(refClass, "value", "I");
40   jint ret = env->GetIntField(obj, refFid);
41   env->DeleteLocalRef(refClass);
42   return ret;
43 }
44 
setIntField(JNIEnv * env,jobject obj,jint value)45 void setIntField(JNIEnv* env, jobject obj, jint value) {
46   jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt");
47   jfieldID refFid = env->GetFieldID(refClass, "value", "I");
48   env->SetIntField(obj, refFid, value);
49   env->DeleteLocalRef(refClass);
50 }
51 
setLongField(JNIEnv * env,jobject obj,jlong value)52 void setLongField(JNIEnv* env, jobject obj, jlong value) {
53   jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong");
54   jfieldID refFid = env->GetFieldID(refClass, "value", "J");
55   env->SetLongField(obj, refFid, value);
56   env->DeleteLocalRef(refClass);
57 }
58 
setStringField(JNIEnv * env,jobject obj,const char * value)59 void setStringField(JNIEnv* env, jobject obj, const char* value) {
60   jclass refClass = env->FindClass("org/apache/tvm/Base$RefString");
61   jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;");
62   env->SetObjectField(obj, refFid, env->NewStringUTF(value));
63   env->DeleteLocalRef(refClass);
64 }
65 
66 // Helper functions for TVMValue
67 jlong getTVMValueLongField(JNIEnv* env, jobject obj,
68                            const char* clsname = "org/apache/tvm/TVMValueLong") {
69   jclass cls = env->FindClass(clsname);
70   jfieldID fid = env->GetFieldID(cls, "value", "J");
71   jlong ret = env->GetLongField(obj, fid);
72   env->DeleteLocalRef(cls);
73   return ret;
74 }
75 
getTVMValueDoubleField(JNIEnv * env,jobject obj)76 jdouble getTVMValueDoubleField(JNIEnv* env, jobject obj) {
77   jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble");
78   jfieldID fid = env->GetFieldID(cls, "value", "D");
79   jdouble ret = env->GetDoubleField(obj, fid);
80   env->DeleteLocalRef(cls);
81   return ret;
82 }
83 
getTVMValueStringField(JNIEnv * env,jobject obj)84 jstring getTVMValueStringField(JNIEnv* env, jobject obj) {
85   jclass cls = env->FindClass("org/apache/tvm/TVMValueString");
86   jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;");
87   jstring ret = static_cast<jstring>(env->GetObjectField(obj, fid));
88   env->DeleteLocalRef(cls);
89   return ret;
90 }
91 
newTVMValueHandle(JNIEnv * env,jlong value)92 jobject newTVMValueHandle(JNIEnv* env, jlong value) {
93   jclass cls = env->FindClass("org/apache/tvm/TVMValueHandle");
94   jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
95   jobject object = env->NewObject(cls, constructor, value);
96   env->DeleteLocalRef(cls);
97   return object;
98 }
99 
newTVMValueLong(JNIEnv * env,jlong value)100 jobject newTVMValueLong(JNIEnv* env, jlong value) {
101   jclass cls = env->FindClass("org/apache/tvm/TVMValueLong");
102   jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
103   jobject object = env->NewObject(cls, constructor, value);
104   env->DeleteLocalRef(cls);
105   return object;
106 }
107 
newTVMValueDouble(JNIEnv * env,jdouble value)108 jobject newTVMValueDouble(JNIEnv* env, jdouble value) {
109   jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble");
110   jmethodID constructor = env->GetMethodID(cls, "<init>", "(D)V");
111   jobject object = env->NewObject(cls, constructor, value);
112   env->DeleteLocalRef(cls);
113   return object;
114 }
115 
newTVMValueString(JNIEnv * env,const char * value)116 jobject newTVMValueString(JNIEnv* env, const char* value) {
117   jstring jvalue = env->NewStringUTF(value);
118   jclass cls = env->FindClass("org/apache/tvm/TVMValueString");
119   jmethodID constructor = env->GetMethodID(cls, "<init>", "(Ljava/lang/String;)V");
120   jobject object = env->NewObject(cls, constructor, jvalue);
121   env->DeleteLocalRef(cls);
122   env->DeleteLocalRef(jvalue);
123   return object;
124 }
125 
newTVMValueBytes(JNIEnv * env,const TVMByteArray * arr)126 jobject newTVMValueBytes(JNIEnv* env, const TVMByteArray* arr) {
127   jbyteArray jarr = env->NewByteArray(arr->size);
128   env->SetByteArrayRegion(jarr, 0, arr->size,
129                           reinterpret_cast<jbyte*>(const_cast<char*>(arr->data)));
130   jclass cls = env->FindClass("org/apache/tvm/TVMValueBytes");
131   jmethodID constructor = env->GetMethodID(cls, "<init>", "([B)V");
132   jobject object = env->NewObject(cls, constructor, jarr);
133   env->DeleteLocalRef(cls);
134   env->DeleteLocalRef(jarr);
135   return object;
136 }
137 
newModule(JNIEnv * env,jlong value)138 jobject newModule(JNIEnv* env, jlong value) {
139   jclass cls = env->FindClass("org/apache/tvm/Module");
140   jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
141   jobject object = env->NewObject(cls, constructor, value);
142   env->DeleteLocalRef(cls);
143   return object;
144 }
145 
newFunction(JNIEnv * env,jlong value)146 jobject newFunction(JNIEnv* env, jlong value) {
147   jclass cls = env->FindClass("org/apache/tvm/Function");
148   jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
149   jobject object = env->NewObject(cls, constructor, value);
150   env->DeleteLocalRef(cls);
151   return object;
152 }
153 
newNDArray(JNIEnv * env,jlong handle,jboolean isview)154 jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) {
155   jclass cls = env->FindClass("org/apache/tvm/NDArrayBase");
156   jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V");
157   jobject object = env->NewObject(cls, constructor, handle, isview);
158   env->DeleteLocalRef(cls);
159   return object;
160 }
161 
newObject(JNIEnv * env,const char * clsname)162 jobject newObject(JNIEnv* env, const char* clsname) {
163   jclass cls = env->FindClass(clsname);
164   jmethodID constructor = env->GetMethodID(cls, "<init>", "()V");
165   jobject object = env->NewObject(cls, constructor);
166   env->DeleteLocalRef(cls);
167   return object;
168 }
169 
fromJavaDType(JNIEnv * env,jobject jdtype,DLDataType * dtype)170 void fromJavaDType(JNIEnv* env, jobject jdtype, DLDataType* dtype) {
171   jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType");
172   dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I")));
173   dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I")));
174   dtype->lanes = (uint16_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "lanes", "I")));
175   env->DeleteLocalRef(tvmTypeClass);
176 }
177 
fromJavaContext(JNIEnv * env,jobject jctx,TVMContext * ctx)178 void fromJavaContext(JNIEnv* env, jobject jctx, TVMContext* ctx) {
179   jclass tvmContextClass = env->FindClass("org/apache/tvm/TVMContext");
180   ctx->device_type = static_cast<DLDeviceType>(
181       env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceType", "I")));
182   ctx->device_id =
183       static_cast<int>(env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceId", "I")));
184   env->DeleteLocalRef(tvmContextClass);
185 }
186 
tvmRetValueToJava(JNIEnv * env,TVMValue value,int tcode)187 jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) {
188   switch (tcode) {
189     case kDLUInt:
190     case kDLInt:
191       return newTVMValueLong(env, static_cast<jlong>(value.v_int64));
192     case kDLFloat:
193       return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64));
194     case kTVMOpaqueHandle:
195       return newTVMValueHandle(env, reinterpret_cast<jlong>(value.v_handle));
196     case kTVMModuleHandle:
197       return newModule(env, reinterpret_cast<jlong>(value.v_handle));
198     case kTVMPackedFuncHandle:
199       return newFunction(env, reinterpret_cast<jlong>(value.v_handle));
200     case kTVMDLTensorHandle:
201       return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), true);
202     case kTVMNDArrayHandle:
203       return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), false);
204     case kTVMStr:
205       return newTVMValueString(env, value.v_str);
206     case kTVMBytes:
207       return newTVMValueBytes(env, reinterpret_cast<TVMByteArray*>(value.v_handle));
208     case kTVMNullptr:
209       return newObject(env, "org/apache/tvm/TVMValueNull");
210     default:
211       LOG(FATAL) << "Do NOT know how to handle return type code " << tcode;
212   }
213   return NULL;
214 }
215 
216 #endif  // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
217