1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*!
20 * \file jni_helper_func.h
21 * \brief Helper functions for operating JVM objects
22 */
23 #include <jni.h>
24
25 #ifndef TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
26 #define TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
27
28 // Helper functions for RefXXX getter & setter
getLongField(JNIEnv * env,jobject obj)29 jlong getLongField(JNIEnv* env, jobject obj) {
30 jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong");
31 jfieldID refFid = env->GetFieldID(refClass, "value", "J");
32 jlong ret = env->GetLongField(obj, refFid);
33 env->DeleteLocalRef(refClass);
34 return ret;
35 }
36
getIntField(JNIEnv * env,jobject obj)37 jint getIntField(JNIEnv* env, jobject obj) {
38 jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt");
39 jfieldID refFid = env->GetFieldID(refClass, "value", "I");
40 jint ret = env->GetIntField(obj, refFid);
41 env->DeleteLocalRef(refClass);
42 return ret;
43 }
44
setIntField(JNIEnv * env,jobject obj,jint value)45 void setIntField(JNIEnv* env, jobject obj, jint value) {
46 jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt");
47 jfieldID refFid = env->GetFieldID(refClass, "value", "I");
48 env->SetIntField(obj, refFid, value);
49 env->DeleteLocalRef(refClass);
50 }
51
setLongField(JNIEnv * env,jobject obj,jlong value)52 void setLongField(JNIEnv* env, jobject obj, jlong value) {
53 jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong");
54 jfieldID refFid = env->GetFieldID(refClass, "value", "J");
55 env->SetLongField(obj, refFid, value);
56 env->DeleteLocalRef(refClass);
57 }
58
setStringField(JNIEnv * env,jobject obj,const char * value)59 void setStringField(JNIEnv* env, jobject obj, const char* value) {
60 jclass refClass = env->FindClass("org/apache/tvm/Base$RefString");
61 jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;");
62 env->SetObjectField(obj, refFid, env->NewStringUTF(value));
63 env->DeleteLocalRef(refClass);
64 }
65
66 // Helper functions for TVMValue
67 jlong getTVMValueLongField(JNIEnv* env, jobject obj,
68 const char* clsname = "org/apache/tvm/TVMValueLong") {
69 jclass cls = env->FindClass(clsname);
70 jfieldID fid = env->GetFieldID(cls, "value", "J");
71 jlong ret = env->GetLongField(obj, fid);
72 env->DeleteLocalRef(cls);
73 return ret;
74 }
75
getTVMValueDoubleField(JNIEnv * env,jobject obj)76 jdouble getTVMValueDoubleField(JNIEnv* env, jobject obj) {
77 jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble");
78 jfieldID fid = env->GetFieldID(cls, "value", "D");
79 jdouble ret = env->GetDoubleField(obj, fid);
80 env->DeleteLocalRef(cls);
81 return ret;
82 }
83
getTVMValueStringField(JNIEnv * env,jobject obj)84 jstring getTVMValueStringField(JNIEnv* env, jobject obj) {
85 jclass cls = env->FindClass("org/apache/tvm/TVMValueString");
86 jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;");
87 jstring ret = static_cast<jstring>(env->GetObjectField(obj, fid));
88 env->DeleteLocalRef(cls);
89 return ret;
90 }
91
newTVMValueHandle(JNIEnv * env,jlong value)92 jobject newTVMValueHandle(JNIEnv* env, jlong value) {
93 jclass cls = env->FindClass("org/apache/tvm/TVMValueHandle");
94 jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
95 jobject object = env->NewObject(cls, constructor, value);
96 env->DeleteLocalRef(cls);
97 return object;
98 }
99
newTVMValueLong(JNIEnv * env,jlong value)100 jobject newTVMValueLong(JNIEnv* env, jlong value) {
101 jclass cls = env->FindClass("org/apache/tvm/TVMValueLong");
102 jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
103 jobject object = env->NewObject(cls, constructor, value);
104 env->DeleteLocalRef(cls);
105 return object;
106 }
107
newTVMValueDouble(JNIEnv * env,jdouble value)108 jobject newTVMValueDouble(JNIEnv* env, jdouble value) {
109 jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble");
110 jmethodID constructor = env->GetMethodID(cls, "<init>", "(D)V");
111 jobject object = env->NewObject(cls, constructor, value);
112 env->DeleteLocalRef(cls);
113 return object;
114 }
115
newTVMValueString(JNIEnv * env,const char * value)116 jobject newTVMValueString(JNIEnv* env, const char* value) {
117 jstring jvalue = env->NewStringUTF(value);
118 jclass cls = env->FindClass("org/apache/tvm/TVMValueString");
119 jmethodID constructor = env->GetMethodID(cls, "<init>", "(Ljava/lang/String;)V");
120 jobject object = env->NewObject(cls, constructor, jvalue);
121 env->DeleteLocalRef(cls);
122 env->DeleteLocalRef(jvalue);
123 return object;
124 }
125
newTVMValueBytes(JNIEnv * env,const TVMByteArray * arr)126 jobject newTVMValueBytes(JNIEnv* env, const TVMByteArray* arr) {
127 jbyteArray jarr = env->NewByteArray(arr->size);
128 env->SetByteArrayRegion(jarr, 0, arr->size,
129 reinterpret_cast<jbyte*>(const_cast<char*>(arr->data)));
130 jclass cls = env->FindClass("org/apache/tvm/TVMValueBytes");
131 jmethodID constructor = env->GetMethodID(cls, "<init>", "([B)V");
132 jobject object = env->NewObject(cls, constructor, jarr);
133 env->DeleteLocalRef(cls);
134 env->DeleteLocalRef(jarr);
135 return object;
136 }
137
newModule(JNIEnv * env,jlong value)138 jobject newModule(JNIEnv* env, jlong value) {
139 jclass cls = env->FindClass("org/apache/tvm/Module");
140 jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
141 jobject object = env->NewObject(cls, constructor, value);
142 env->DeleteLocalRef(cls);
143 return object;
144 }
145
newFunction(JNIEnv * env,jlong value)146 jobject newFunction(JNIEnv* env, jlong value) {
147 jclass cls = env->FindClass("org/apache/tvm/Function");
148 jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
149 jobject object = env->NewObject(cls, constructor, value);
150 env->DeleteLocalRef(cls);
151 return object;
152 }
153
newNDArray(JNIEnv * env,jlong handle,jboolean isview)154 jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) {
155 jclass cls = env->FindClass("org/apache/tvm/NDArrayBase");
156 jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V");
157 jobject object = env->NewObject(cls, constructor, handle, isview);
158 env->DeleteLocalRef(cls);
159 return object;
160 }
161
newObject(JNIEnv * env,const char * clsname)162 jobject newObject(JNIEnv* env, const char* clsname) {
163 jclass cls = env->FindClass(clsname);
164 jmethodID constructor = env->GetMethodID(cls, "<init>", "()V");
165 jobject object = env->NewObject(cls, constructor);
166 env->DeleteLocalRef(cls);
167 return object;
168 }
169
fromJavaDType(JNIEnv * env,jobject jdtype,DLDataType * dtype)170 void fromJavaDType(JNIEnv* env, jobject jdtype, DLDataType* dtype) {
171 jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType");
172 dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I")));
173 dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I")));
174 dtype->lanes = (uint16_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "lanes", "I")));
175 env->DeleteLocalRef(tvmTypeClass);
176 }
177
fromJavaContext(JNIEnv * env,jobject jctx,TVMContext * ctx)178 void fromJavaContext(JNIEnv* env, jobject jctx, TVMContext* ctx) {
179 jclass tvmContextClass = env->FindClass("org/apache/tvm/TVMContext");
180 ctx->device_type = static_cast<DLDeviceType>(
181 env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceType", "I")));
182 ctx->device_id =
183 static_cast<int>(env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceId", "I")));
184 env->DeleteLocalRef(tvmContextClass);
185 }
186
tvmRetValueToJava(JNIEnv * env,TVMValue value,int tcode)187 jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) {
188 switch (tcode) {
189 case kDLUInt:
190 case kDLInt:
191 return newTVMValueLong(env, static_cast<jlong>(value.v_int64));
192 case kDLFloat:
193 return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64));
194 case kTVMOpaqueHandle:
195 return newTVMValueHandle(env, reinterpret_cast<jlong>(value.v_handle));
196 case kTVMModuleHandle:
197 return newModule(env, reinterpret_cast<jlong>(value.v_handle));
198 case kTVMPackedFuncHandle:
199 return newFunction(env, reinterpret_cast<jlong>(value.v_handle));
200 case kTVMDLTensorHandle:
201 return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), true);
202 case kTVMNDArrayHandle:
203 return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), false);
204 case kTVMStr:
205 return newTVMValueString(env, value.v_str);
206 case kTVMBytes:
207 return newTVMValueBytes(env, reinterpret_cast<TVMByteArray*>(value.v_handle));
208 case kTVMNullptr:
209 return newObject(env, "org/apache/tvm/TVMValueNull");
210 default:
211 LOG(FATAL) << "Do NOT know how to handle return type code " << tcode;
212 }
213 return NULL;
214 }
215
216 #endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
217