1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <google/protobuf/io/coded_stream.h>
19
20 #include <map>
21 #include <memory>
22 #include <mutex>
23 #include <sstream>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include <arrow/builder.h>
29 #include <arrow/record_batch.h>
30 #include <arrow/type.h>
31
32 #include "Types.pb.h"
33 #include "gandiva/configuration.h"
34 #include "gandiva/decimal_scalar.h"
35 #include "gandiva/filter.h"
36 #include "gandiva/jni/config_holder.h"
37 #include "gandiva/jni/env_helper.h"
38 #include "gandiva/jni/id_to_module_map.h"
39 #include "gandiva/jni/module_holder.h"
40 #include "gandiva/projector.h"
41 #include "gandiva/selection_vector.h"
42 #include "gandiva/tree_expr_builder.h"
43 #include "jni/org_apache_arrow_gandiva_evaluator_JniWrapper.h"
44
45 using gandiva::ConditionPtr;
46 using gandiva::DataTypePtr;
47 using gandiva::ExpressionPtr;
48 using gandiva::ExpressionVector;
49 using gandiva::FieldPtr;
50 using gandiva::FieldVector;
51 using gandiva::Filter;
52 using gandiva::NodePtr;
53 using gandiva::NodeVector;
54 using gandiva::Projector;
55 using gandiva::SchemaPtr;
56 using gandiva::Status;
57 using gandiva::TreeExprBuilder;
58
59 using gandiva::ArrayDataVector;
60 using gandiva::ConfigHolder;
61 using gandiva::Configuration;
62 using gandiva::ConfigurationBuilder;
63 using gandiva::FilterHolder;
64 using gandiva::ProjectorHolder;
65
// forward declarations
// ProtoTypeToNode is declared up front because the per-node converters below
// (function / if / and / or / in) call it for their child nodes before it is
// defined.
NodePtr ProtoTypeToNode(const types::TreeNode& node);

// JNI version passed to JavaVM::GetEnv and returned from JNI_OnLoad.
static jint JNI_VERSION = JNI_VERSION_1_6;

// extern refs - initialized for other modules.
// Global reference to the Java ConfigurationBuilder class; set in JNI_OnLoad
// and deleted in JNI_OnUnload.
jclass configuration_builder_class_;

// refs for self.
// Global class references plus cached method / field ids. All of them are
// populated in JNI_OnLoad; the class references are deleted in JNI_OnUnload.
static jclass gandiva_exception_;
static jclass vector_expander_class_;
static jclass vector_expander_ret_class_;
static jmethodID vector_expander_method_;
static jfieldID vector_expander_ret_address_;
static jfieldID vector_expander_ret_capacity_;

// module maps
// Registries that map the module ids handed back to Java onto the native
// projector / filter holders (see Insert / Lookup / Erase usages below).
gandiva::IdToModuleMap<std::shared_ptr<ProjectorHolder>> projector_modules_;
gandiva::IdToModuleMap<std::shared_ptr<FilterHolder>> filter_modules_;
85
JNI_OnLoad(JavaVM * vm,void * reserved)86 jint JNI_OnLoad(JavaVM* vm, void* reserved) {
87 JNIEnv* env;
88 if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
89 return JNI_ERR;
90 }
91 jclass local_configuration_builder_class_ =
92 env->FindClass("org/apache/arrow/gandiva/evaluator/ConfigurationBuilder");
93 configuration_builder_class_ =
94 (jclass)env->NewGlobalRef(local_configuration_builder_class_);
95 env->DeleteLocalRef(local_configuration_builder_class_);
96
97 jclass localExceptionClass =
98 env->FindClass("org/apache/arrow/gandiva/exceptions/GandivaException");
99 gandiva_exception_ = (jclass)env->NewGlobalRef(localExceptionClass);
100 env->ExceptionDescribe();
101 env->DeleteLocalRef(localExceptionClass);
102
103 jclass local_expander_class =
104 env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander");
105 vector_expander_class_ = (jclass)env->NewGlobalRef(local_expander_class);
106 env->DeleteLocalRef(local_expander_class);
107
108 vector_expander_method_ = env->GetMethodID(
109 vector_expander_class_, "expandOutputVectorAtIndex",
110 "(IJ)Lorg/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult;");
111
112 jclass local_expander_ret_class =
113 env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult");
114 vector_expander_ret_class_ = (jclass)env->NewGlobalRef(local_expander_ret_class);
115 env->DeleteLocalRef(local_expander_ret_class);
116
117 vector_expander_ret_address_ =
118 env->GetFieldID(vector_expander_ret_class_, "address", "J");
119 vector_expander_ret_capacity_ =
120 env->GetFieldID(vector_expander_ret_class_, "capacity", "J");
121 return JNI_VERSION;
122 }
123
JNI_OnUnload(JavaVM * vm,void * reserved)124 void JNI_OnUnload(JavaVM* vm, void* reserved) {
125 JNIEnv* env;
126 vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION);
127 env->DeleteGlobalRef(configuration_builder_class_);
128 env->DeleteGlobalRef(gandiva_exception_);
129 env->DeleteGlobalRef(vector_expander_class_);
130 env->DeleteGlobalRef(vector_expander_ret_class_);
131 }
132
ProtoTypeToTime32(const types::ExtGandivaType & ext_type)133 DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType& ext_type) {
134 switch (ext_type.timeunit()) {
135 case types::SEC:
136 return arrow::time32(arrow::TimeUnit::SECOND);
137 case types::MILLISEC:
138 return arrow::time32(arrow::TimeUnit::MILLI);
139 default:
140 std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time32\n";
141 return nullptr;
142 }
143 }
144
ProtoTypeToTime64(const types::ExtGandivaType & ext_type)145 DataTypePtr ProtoTypeToTime64(const types::ExtGandivaType& ext_type) {
146 switch (ext_type.timeunit()) {
147 case types::MICROSEC:
148 return arrow::time64(arrow::TimeUnit::MICRO);
149 case types::NANOSEC:
150 return arrow::time64(arrow::TimeUnit::NANO);
151 default:
152 std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time64\n";
153 return nullptr;
154 }
155 }
156
ProtoTypeToTimestamp(const types::ExtGandivaType & ext_type)157 DataTypePtr ProtoTypeToTimestamp(const types::ExtGandivaType& ext_type) {
158 switch (ext_type.timeunit()) {
159 case types::SEC:
160 return arrow::timestamp(arrow::TimeUnit::SECOND);
161 case types::MILLISEC:
162 return arrow::timestamp(arrow::TimeUnit::MILLI);
163 case types::MICROSEC:
164 return arrow::timestamp(arrow::TimeUnit::MICRO);
165 case types::NANOSEC:
166 return arrow::timestamp(arrow::TimeUnit::NANO);
167 default:
168 std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for timestamp\n";
169 return nullptr;
170 }
171 }
172
ProtoTypeToInterval(const types::ExtGandivaType & ext_type)173 DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) {
174 switch (ext_type.intervaltype()) {
175 case types::YEAR_MONTH:
176 return arrow::month_interval();
177 case types::DAY_TIME:
178 return arrow::day_time_interval();
179 default:
180 std::cerr << "Unknown interval type: " << ext_type.intervaltype() << "\n";
181 return nullptr;
182 }
183 }
184
ProtoTypeToDataType(const types::ExtGandivaType & ext_type)185 DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) {
186 switch (ext_type.type()) {
187 case types::NONE:
188 return arrow::null();
189 case types::BOOL:
190 return arrow::boolean();
191 case types::UINT8:
192 return arrow::uint8();
193 case types::INT8:
194 return arrow::int8();
195 case types::UINT16:
196 return arrow::uint16();
197 case types::INT16:
198 return arrow::int16();
199 case types::UINT32:
200 return arrow::uint32();
201 case types::INT32:
202 return arrow::int32();
203 case types::UINT64:
204 return arrow::uint64();
205 case types::INT64:
206 return arrow::int64();
207 case types::HALF_FLOAT:
208 return arrow::float16();
209 case types::FLOAT:
210 return arrow::float32();
211 case types::DOUBLE:
212 return arrow::float64();
213 case types::UTF8:
214 return arrow::utf8();
215 case types::BINARY:
216 return arrow::binary();
217 case types::DATE32:
218 return arrow::date32();
219 case types::DATE64:
220 return arrow::date64();
221 case types::DECIMAL:
222 // TODO: error handling
223 return arrow::decimal(ext_type.precision(), ext_type.scale());
224 case types::TIME32:
225 return ProtoTypeToTime32(ext_type);
226 case types::TIME64:
227 return ProtoTypeToTime64(ext_type);
228 case types::TIMESTAMP:
229 return ProtoTypeToTimestamp(ext_type);
230 case types::INTERVAL:
231 return ProtoTypeToInterval(ext_type);
232 case types::FIXED_SIZE_BINARY:
233 case types::LIST:
234 case types::STRUCT:
235 case types::UNION:
236 case types::DICTIONARY:
237 case types::MAP:
238 std::cerr << "Unhandled data type: " << ext_type.type() << "\n";
239 return nullptr;
240
241 default:
242 std::cerr << "Unknown data type: " << ext_type.type() << "\n";
243 return nullptr;
244 }
245 }
246
ProtoTypeToField(const types::Field & f)247 FieldPtr ProtoTypeToField(const types::Field& f) {
248 const std::string& name = f.name();
249 DataTypePtr type = ProtoTypeToDataType(f.type());
250 bool nullable = true;
251 if (f.has_nullable()) {
252 nullable = f.nullable();
253 }
254
255 return field(name, type, nullable);
256 }
257
ProtoTypeToFieldNode(const types::FieldNode & node)258 NodePtr ProtoTypeToFieldNode(const types::FieldNode& node) {
259 FieldPtr field_ptr = ProtoTypeToField(node.field());
260 if (field_ptr == nullptr) {
261 std::cerr << "Unable to create field node from protobuf\n";
262 return nullptr;
263 }
264
265 return TreeExprBuilder::MakeField(field_ptr);
266 }
267
ProtoTypeToFnNode(const types::FunctionNode & node)268 NodePtr ProtoTypeToFnNode(const types::FunctionNode& node) {
269 const std::string& name = node.functionname();
270 NodeVector children;
271
272 for (int i = 0; i < node.inargs_size(); i++) {
273 const types::TreeNode& arg = node.inargs(i);
274
275 NodePtr n = ProtoTypeToNode(arg);
276 if (n == nullptr) {
277 std::cerr << "Unable to create argument for function: " << name << "\n";
278 return nullptr;
279 }
280
281 children.push_back(n);
282 }
283
284 DataTypePtr return_type = ProtoTypeToDataType(node.returntype());
285 if (return_type == nullptr) {
286 std::cerr << "Unknown return type for function: " << name << "\n";
287 return nullptr;
288 }
289
290 return TreeExprBuilder::MakeFunction(name, children, return_type);
291 }
292
ProtoTypeToIfNode(const types::IfNode & node)293 NodePtr ProtoTypeToIfNode(const types::IfNode& node) {
294 NodePtr cond = ProtoTypeToNode(node.cond());
295 if (cond == nullptr) {
296 std::cerr << "Unable to create cond node for if node\n";
297 return nullptr;
298 }
299
300 NodePtr then_node = ProtoTypeToNode(node.thennode());
301 if (then_node == nullptr) {
302 std::cerr << "Unable to create then node for if node\n";
303 return nullptr;
304 }
305
306 NodePtr else_node = ProtoTypeToNode(node.elsenode());
307 if (else_node == nullptr) {
308 std::cerr << "Unable to create else node for if node\n";
309 return nullptr;
310 }
311
312 DataTypePtr return_type = ProtoTypeToDataType(node.returntype());
313 if (return_type == nullptr) {
314 std::cerr << "Unknown return type for if node\n";
315 return nullptr;
316 }
317
318 return TreeExprBuilder::MakeIf(cond, then_node, else_node, return_type);
319 }
320
ProtoTypeToAndNode(const types::AndNode & node)321 NodePtr ProtoTypeToAndNode(const types::AndNode& node) {
322 NodeVector children;
323
324 for (int i = 0; i < node.args_size(); i++) {
325 const types::TreeNode& arg = node.args(i);
326
327 NodePtr n = ProtoTypeToNode(arg);
328 if (n == nullptr) {
329 std::cerr << "Unable to create argument for boolean and\n";
330 return nullptr;
331 }
332 children.push_back(n);
333 }
334 return TreeExprBuilder::MakeAnd(children);
335 }
336
ProtoTypeToOrNode(const types::OrNode & node)337 NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
338 NodeVector children;
339
340 for (int i = 0; i < node.args_size(); i++) {
341 const types::TreeNode& arg = node.args(i);
342
343 NodePtr n = ProtoTypeToNode(arg);
344 if (n == nullptr) {
345 std::cerr << "Unable to create argument for boolean or\n";
346 return nullptr;
347 }
348 children.push_back(n);
349 }
350 return TreeExprBuilder::MakeOr(children);
351 }
352
ProtoTypeToInNode(const types::InNode & node)353 NodePtr ProtoTypeToInNode(const types::InNode& node) {
354 NodePtr field = ProtoTypeToNode(node.node());
355
356 if (node.has_intvalues()) {
357 std::unordered_set<int32_t> int_values;
358 for (int i = 0; i < node.intvalues().intvalues_size(); i++) {
359 int_values.insert(node.intvalues().intvalues(i).value());
360 }
361 return TreeExprBuilder::MakeInExpressionInt32(field, int_values);
362 }
363
364 if (node.has_longvalues()) {
365 std::unordered_set<int64_t> long_values;
366 for (int i = 0; i < node.longvalues().longvalues_size(); i++) {
367 long_values.insert(node.longvalues().longvalues(i).value());
368 }
369 return TreeExprBuilder::MakeInExpressionInt64(field, long_values);
370 }
371
372 if (node.has_decimalvalues()) {
373 std::unordered_set<gandiva::DecimalScalar128> decimal_values;
374 for (int i = 0; i < node.decimalvalues().decimalvalues_size(); i++) {
375 decimal_values.insert(
376 gandiva::DecimalScalar128(node.decimalvalues().decimalvalues(i).value(),
377 node.decimalvalues().decimalvalues(i).precision(),
378 node.decimalvalues().decimalvalues(i).scale()));
379 }
380 return TreeExprBuilder::MakeInExpressionDecimal(field, decimal_values);
381 }
382
383 if (node.has_floatvalues()) {
384 std::unordered_set<float> float_values;
385 for (int i = 0; i < node.floatvalues().floatvalues_size(); i++) {
386 float_values.insert(node.floatvalues().floatvalues(i).value());
387 }
388 return TreeExprBuilder::MakeInExpressionFloat(field, float_values);
389 }
390
391 if (node.has_doublevalues()) {
392 std::unordered_set<double> double_values;
393 for (int i = 0; i < node.doublevalues().doublevalues_size(); i++) {
394 double_values.insert(node.doublevalues().doublevalues(i).value());
395 }
396 return TreeExprBuilder::MakeInExpressionDouble(field, double_values);
397 }
398
399 if (node.has_stringvalues()) {
400 std::unordered_set<std::string> stringvalues;
401 for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
402 stringvalues.insert(node.stringvalues().stringvalues(i).value());
403 }
404 return TreeExprBuilder::MakeInExpressionString(field, stringvalues);
405 }
406
407 if (node.has_binaryvalues()) {
408 std::unordered_set<std::string> stringvalues;
409 for (int i = 0; i < node.binaryvalues().binaryvalues_size(); i++) {
410 stringvalues.insert(node.binaryvalues().binaryvalues(i).value());
411 }
412 return TreeExprBuilder::MakeInExpressionBinary(field, stringvalues);
413 }
414 // not supported yet.
415 std::cerr << "Unknown constant type for in expression.\n";
416 return nullptr;
417 }
418
ProtoTypeToNullNode(const types::NullNode & node)419 NodePtr ProtoTypeToNullNode(const types::NullNode& node) {
420 DataTypePtr data_type = ProtoTypeToDataType(node.type());
421 if (data_type == nullptr) {
422 std::cerr << "Unknown type " << data_type->ToString() << " for null node\n";
423 return nullptr;
424 }
425
426 return TreeExprBuilder::MakeNull(data_type);
427 }
428
ProtoTypeToNode(const types::TreeNode & node)429 NodePtr ProtoTypeToNode(const types::TreeNode& node) {
430 if (node.has_fieldnode()) {
431 return ProtoTypeToFieldNode(node.fieldnode());
432 }
433
434 if (node.has_fnnode()) {
435 return ProtoTypeToFnNode(node.fnnode());
436 }
437
438 if (node.has_ifnode()) {
439 return ProtoTypeToIfNode(node.ifnode());
440 }
441
442 if (node.has_andnode()) {
443 return ProtoTypeToAndNode(node.andnode());
444 }
445
446 if (node.has_ornode()) {
447 return ProtoTypeToOrNode(node.ornode());
448 }
449
450 if (node.has_innode()) {
451 return ProtoTypeToInNode(node.innode());
452 }
453
454 if (node.has_nullnode()) {
455 return ProtoTypeToNullNode(node.nullnode());
456 }
457
458 if (node.has_intnode()) {
459 return TreeExprBuilder::MakeLiteral(node.intnode().value());
460 }
461
462 if (node.has_floatnode()) {
463 return TreeExprBuilder::MakeLiteral(node.floatnode().value());
464 }
465
466 if (node.has_longnode()) {
467 return TreeExprBuilder::MakeLiteral(node.longnode().value());
468 }
469
470 if (node.has_booleannode()) {
471 return TreeExprBuilder::MakeLiteral(node.booleannode().value());
472 }
473
474 if (node.has_doublenode()) {
475 return TreeExprBuilder::MakeLiteral(node.doublenode().value());
476 }
477
478 if (node.has_stringnode()) {
479 return TreeExprBuilder::MakeStringLiteral(node.stringnode().value());
480 }
481
482 if (node.has_binarynode()) {
483 return TreeExprBuilder::MakeBinaryLiteral(node.binarynode().value());
484 }
485
486 if (node.has_decimalnode()) {
487 std::string value = node.decimalnode().value();
488 gandiva::DecimalScalar128 literal(value, node.decimalnode().precision(),
489 node.decimalnode().scale());
490 return TreeExprBuilder::MakeDecimalLiteral(literal);
491 }
492 std::cerr << "Unknown node type in protobuf\n";
493 return nullptr;
494 }
495
ProtoTypeToExpression(const types::ExpressionRoot & root)496 ExpressionPtr ProtoTypeToExpression(const types::ExpressionRoot& root) {
497 NodePtr root_node = ProtoTypeToNode(root.root());
498 if (root_node == nullptr) {
499 std::cerr << "Unable to create expression node from expression protobuf\n";
500 return nullptr;
501 }
502
503 FieldPtr field = ProtoTypeToField(root.resulttype());
504 if (field == nullptr) {
505 std::cerr << "Unable to extra return field from expression protobuf\n";
506 return nullptr;
507 }
508
509 return TreeExprBuilder::MakeExpression(root_node, field);
510 }
511
ProtoTypeToCondition(const types::Condition & condition)512 ConditionPtr ProtoTypeToCondition(const types::Condition& condition) {
513 NodePtr root_node = ProtoTypeToNode(condition.root());
514 if (root_node == nullptr) {
515 return nullptr;
516 }
517
518 return TreeExprBuilder::MakeCondition(root_node);
519 }
520
ProtoTypeToSchema(const types::Schema & schema)521 SchemaPtr ProtoTypeToSchema(const types::Schema& schema) {
522 std::vector<FieldPtr> fields;
523
524 for (int i = 0; i < schema.columns_size(); i++) {
525 FieldPtr field = ProtoTypeToField(schema.columns(i));
526 if (field == nullptr) {
527 std::cerr << "Unable to extract arrow field from schema\n";
528 return nullptr;
529 }
530
531 fields.push_back(field);
532 }
533
534 return arrow::schema(fields);
535 }
536
537 // Common for both projector and filters.
538
ParseProtobuf(uint8_t * buf,int bufLen,google::protobuf::Message * msg)539 bool ParseProtobuf(uint8_t* buf, int bufLen, google::protobuf::Message* msg) {
540 google::protobuf::io::CodedInputStream cis(buf, bufLen);
541 cis.SetRecursionLimit(1000);
542 return msg->ParseFromCodedStream(&cis);
543 }
544
make_record_batch_with_buf_addrs(SchemaPtr schema,int num_rows,jlong * in_buf_addrs,jlong * in_buf_sizes,int in_bufs_len,std::shared_ptr<arrow::RecordBatch> * batch)545 Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows,
546 jlong* in_buf_addrs, jlong* in_buf_sizes,
547 int in_bufs_len,
548 std::shared_ptr<arrow::RecordBatch>* batch) {
549 std::vector<std::shared_ptr<arrow::ArrayData>> columns;
550 auto num_fields = schema->num_fields();
551 int buf_idx = 0;
552 int sz_idx = 0;
553
554 for (int i = 0; i < num_fields; i++) {
555 auto field = schema->field(i);
556 std::vector<std::shared_ptr<arrow::Buffer>> buffers;
557
558 if (buf_idx >= in_bufs_len) {
559 return Status::Invalid("insufficient number of in_buf_addrs");
560 }
561 jlong validity_addr = in_buf_addrs[buf_idx++];
562 jlong validity_size = in_buf_sizes[sz_idx++];
563 auto validity = std::shared_ptr<arrow::Buffer>(
564 new arrow::Buffer(reinterpret_cast<uint8_t*>(validity_addr), validity_size));
565 buffers.push_back(validity);
566
567 if (buf_idx >= in_bufs_len) {
568 return Status::Invalid("insufficient number of in_buf_addrs");
569 }
570 jlong value_addr = in_buf_addrs[buf_idx++];
571 jlong value_size = in_buf_sizes[sz_idx++];
572 auto data = std::shared_ptr<arrow::Buffer>(
573 new arrow::Buffer(reinterpret_cast<uint8_t*>(value_addr), value_size));
574 buffers.push_back(data);
575
576 if (arrow::is_binary_like(field->type()->id())) {
577 if (buf_idx >= in_bufs_len) {
578 return Status::Invalid("insufficient number of in_buf_addrs");
579 }
580
581 // add offsets buffer for variable-len fields.
582 jlong offsets_addr = in_buf_addrs[buf_idx++];
583 jlong offsets_size = in_buf_sizes[sz_idx++];
584 auto offsets = std::shared_ptr<arrow::Buffer>(
585 new arrow::Buffer(reinterpret_cast<uint8_t*>(offsets_addr), offsets_size));
586 buffers.push_back(offsets);
587 }
588
589 auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers));
590 columns.push_back(array_data);
591 }
592 *batch = arrow::RecordBatch::Make(schema, num_rows, columns);
593 return Status::OK();
594 }
595
596 // projector related functions.
releaseProjectorInput(jbyteArray schema_arr,jbyte * schema_bytes,jbyteArray exprs_arr,jbyte * exprs_bytes,JNIEnv * env)597 void releaseProjectorInput(jbyteArray schema_arr, jbyte* schema_bytes,
598 jbyteArray exprs_arr, jbyte* exprs_bytes, JNIEnv* env) {
599 env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
600 env->ReleaseByteArrayElements(exprs_arr, exprs_bytes, JNI_ABORT);
601 }
602
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildProjector(JNIEnv * env,jobject obj,jbyteArray schema_arr,jbyteArray exprs_arr,jint selection_vector_type,jlong configuration_id)603 JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildProjector(
604 JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray exprs_arr,
605 jint selection_vector_type, jlong configuration_id) {
606 jlong module_id = 0LL;
607 std::shared_ptr<Projector> projector;
608 std::shared_ptr<ProjectorHolder> holder;
609
610 types::Schema schema;
611 jsize schema_len = env->GetArrayLength(schema_arr);
612 jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0);
613
614 types::ExpressionList exprs;
615 jsize exprs_len = env->GetArrayLength(exprs_arr);
616 jbyte* exprs_bytes = env->GetByteArrayElements(exprs_arr, 0);
617
618 ExpressionVector expr_vector;
619 SchemaPtr schema_ptr;
620 FieldVector ret_types;
621 gandiva::Status status;
622 auto mode = gandiva::SelectionVector::MODE_NONE;
623
624 std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id);
625 std::stringstream ss;
626
627 if (config == nullptr) {
628 ss << "configuration is mandatory.";
629 releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
630 goto err_out;
631 }
632
633 if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) {
634 ss << "Unable to parse schema protobuf\n";
635 releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
636 goto err_out;
637 }
638
639 if (!ParseProtobuf(reinterpret_cast<uint8_t*>(exprs_bytes), exprs_len, &exprs)) {
640 releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
641 ss << "Unable to parse expressions protobuf\n";
642 goto err_out;
643 }
644
645 // convert types::Schema to arrow::Schema
646 schema_ptr = ProtoTypeToSchema(schema);
647 if (schema_ptr == nullptr) {
648 ss << "Unable to construct arrow schema object from schema protobuf\n";
649 releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
650 goto err_out;
651 }
652
653 // create Expression out of the list of exprs
654 for (int i = 0; i < exprs.exprs_size(); i++) {
655 ExpressionPtr root = ProtoTypeToExpression(exprs.exprs(i));
656
657 if (root == nullptr) {
658 ss << "Unable to construct expression object from expression protobuf\n";
659 releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
660 goto err_out;
661 }
662
663 expr_vector.push_back(root);
664 ret_types.push_back(root->result());
665 }
666
667 switch (selection_vector_type) {
668 case types::SV_NONE:
669 mode = gandiva::SelectionVector::MODE_NONE;
670 break;
671 case types::SV_INT16:
672 mode = gandiva::SelectionVector::MODE_UINT16;
673 break;
674 case types::SV_INT32:
675 mode = gandiva::SelectionVector::MODE_UINT32;
676 break;
677 }
678 // good to invoke the evaluator now
679 status = Projector::Make(schema_ptr, expr_vector, mode, config, &projector);
680
681 if (!status.ok()) {
682 ss << "Failed to make LLVM module due to " << status.message() << "\n";
683 releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
684 goto err_out;
685 }
686
687 // store the result in a map
688 holder = std::shared_ptr<ProjectorHolder>(
689 new ProjectorHolder(schema_ptr, ret_types, std::move(projector)));
690 module_id = projector_modules_.Insert(holder);
691 releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
692 return module_id;
693
694 err_out:
695 env->ThrowNew(gandiva_exception_, ss.str().c_str());
696 return module_id;
697 }
698
699 ///
700 /// \brief Resizable buffer which resizes by doing a callback into java.
701 ///
702 class JavaResizableBuffer : public arrow::ResizableBuffer {
703 public:
JavaResizableBuffer(JNIEnv * env,jobject jexpander,int32_t vector_idx,uint8_t * buffer,int32_t len)704 JavaResizableBuffer(JNIEnv* env, jobject jexpander, int32_t vector_idx, uint8_t* buffer,
705 int32_t len)
706 : ResizableBuffer(buffer, len),
707 env_(env),
708 jexpander_(jexpander),
709 vector_idx_(vector_idx) {
710 size_ = 0;
711 }
712
713 Status Resize(const int64_t new_size, bool shrink_to_fit) override;
714
Reserve(const int64_t new_capacity)715 Status Reserve(const int64_t new_capacity) override {
716 return Status::NotImplemented("reserve not implemented");
717 }
718
719 private:
720 JNIEnv* env_;
721 jobject jexpander_;
722 int32_t vector_idx_;
723 };
724
Resize(const int64_t new_size,bool shrink_to_fit)725 Status JavaResizableBuffer::Resize(const int64_t new_size, bool shrink_to_fit) {
726 if (shrink_to_fit == true) {
727 return Status::NotImplemented("shrink not implemented");
728 }
729
730 if (ARROW_PREDICT_TRUE(new_size < capacity())) {
731 // no need to expand.
732 size_ = new_size;
733 return Status::OK();
734 }
735
736 // callback into java to expand the buffer
737 jobject ret =
738 env_->CallObjectMethod(jexpander_, vector_expander_method_, vector_idx_, new_size);
739 if (env_->ExceptionCheck()) {
740 env_->ExceptionDescribe();
741 env_->ExceptionClear();
742 return Status::OutOfMemory("buffer expand failed in java");
743 }
744
745 jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_);
746 jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_);
747 DCHECK_GE(ret_capacity, new_size);
748
749 data_ = reinterpret_cast<uint8_t*>(ret_address);
750 size_ = new_size;
751 capacity_ = ret_capacity;
752 return Status::OK();
753 }
754
// Bounds check for the output buffer cursor. When `idx` has run past `len`,
// sets the enclosing `status` variable to Invalid and breaks out of the
// innermost loop. Must be used inside a loop with a `status` variable in scope.
// Arguments are parenthesised so that expression arguments expand safely.
#define CHECK_OUT_BUFFER_IDX_AND_BREAK(idx, len)                               \
  if ((idx) >= (len)) {                                                        \
    status = gandiva::Status::Invalid("insufficient number of out_buf_addrs"); \
    break;                                                                     \
  }
760
761 JNIEXPORT void JNICALL
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector(JNIEnv * env,jobject object,jobject jexpander,jlong module_id,jint num_rows,jlongArray buf_addrs,jlongArray buf_sizes,jint sel_vec_type,jint sel_vec_rows,jlong sel_vec_addr,jlong sel_vec_size,jlongArray out_buf_addrs,jlongArray out_buf_sizes)762 Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector(
763 JNIEnv* env, jobject object, jobject jexpander, jlong module_id, jint num_rows,
764 jlongArray buf_addrs, jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows,
765 jlong sel_vec_addr, jlong sel_vec_size, jlongArray out_buf_addrs,
766 jlongArray out_buf_sizes) {
767 Status status;
768 std::shared_ptr<ProjectorHolder> holder = projector_modules_.Lookup(module_id);
769 if (holder == nullptr) {
770 std::stringstream ss;
771 ss << "Unknown module id " << module_id;
772 env->ThrowNew(gandiva_exception_, ss.str().c_str());
773 return;
774 }
775
776 int in_bufs_len = env->GetArrayLength(buf_addrs);
777 if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
778 env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
779 return;
780 }
781
782 int out_bufs_len = env->GetArrayLength(out_buf_addrs);
783 if (out_bufs_len != env->GetArrayLength(out_buf_sizes)) {
784 env->ThrowNew(gandiva_exception_,
785 "mismatch in arraylen of out_buf_addrs and out_buf_sizes");
786 return;
787 }
788
789 jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
790 jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
791
792 jlong* out_bufs = env->GetLongArrayElements(out_buf_addrs, 0);
793 jlong* out_sizes = env->GetLongArrayElements(out_buf_sizes, 0);
794
795 do {
796 std::shared_ptr<arrow::RecordBatch> in_batch;
797 status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
798 in_buf_sizes, in_bufs_len, &in_batch);
799 if (!status.ok()) {
800 break;
801 }
802
803 std::shared_ptr<gandiva::SelectionVector> selection_vector;
804 auto selection_buffer = std::make_shared<arrow::Buffer>(
805 reinterpret_cast<uint8_t*>(sel_vec_addr), sel_vec_size);
806 int output_row_count = 0;
807 switch (sel_vec_type) {
808 case types::SV_NONE: {
809 output_row_count = num_rows;
810 break;
811 }
812 case types::SV_INT16: {
813 status = gandiva::SelectionVector::MakeImmutableInt16(
814 sel_vec_rows, selection_buffer, &selection_vector);
815 output_row_count = sel_vec_rows;
816 break;
817 }
818 case types::SV_INT32: {
819 status = gandiva::SelectionVector::MakeImmutableInt32(
820 sel_vec_rows, selection_buffer, &selection_vector);
821 output_row_count = sel_vec_rows;
822 break;
823 }
824 }
825 if (!status.ok()) {
826 break;
827 }
828
829 auto ret_types = holder->rettypes();
830 ArrayDataVector output;
831 int buf_idx = 0;
832 int sz_idx = 0;
833 int output_vector_idx = 0;
834 for (FieldPtr field : ret_types) {
835 std::vector<std::shared_ptr<arrow::Buffer>> buffers;
836
837 CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
838 uint8_t* validity_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
839 jlong bitmap_sz = out_sizes[sz_idx++];
840 buffers.push_back(std::make_shared<arrow::MutableBuffer>(validity_buf, bitmap_sz));
841
842 if (arrow::is_binary_like(field->type()->id())) {
843 CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
844 uint8_t* offsets_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
845 jlong offsets_sz = out_sizes[sz_idx++];
846 buffers.push_back(
847 std::make_shared<arrow::MutableBuffer>(offsets_buf, offsets_sz));
848 }
849
850 CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
851 uint8_t* value_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
852 jlong data_sz = out_sizes[sz_idx++];
853 if (arrow::is_binary_like(field->type()->id())) {
854 if (jexpander == nullptr) {
855 status = Status::Invalid(
856 "expression has variable len output columns, but the expander object is "
857 "null");
858 break;
859 }
860 buffers.push_back(std::make_shared<JavaResizableBuffer>(
861 env, jexpander, output_vector_idx, value_buf, data_sz));
862 } else {
863 buffers.push_back(std::make_shared<arrow::MutableBuffer>(value_buf, data_sz));
864 }
865
866 auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers);
867 output.push_back(array_data);
868 ++output_vector_idx;
869 }
870 if (!status.ok()) {
871 break;
872 }
873 status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output);
874 } while (0);
875
876 env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
877 env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
878 env->ReleaseLongArrayElements(out_buf_addrs, out_bufs, JNI_ABORT);
879 env->ReleaseLongArrayElements(out_buf_sizes, out_sizes, JNI_ABORT);
880
881 if (!status.ok()) {
882 std::stringstream ss;
883 ss << "Evaluate returned " << status.message() << "\n";
884 env->ThrowNew(gandiva_exception_, status.message().c_str());
885 return;
886 }
887 }
888
/// \brief JNI entry point: removes the projector registered under `module_id`
/// from `projector_modules_`.
///
/// \param env JNI environment of the calling thread (unused)
/// \param cls the Java caller object (unused)
/// \param module_id id previously returned by buildProjector
JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeProjector(
    JNIEnv* env, jobject cls, jlong module_id) {
  projector_modules_.Erase(module_id);
}
893
894 // filter related functions.
releaseFilterInput(jbyteArray schema_arr,jbyte * schema_bytes,jbyteArray condition_arr,jbyte * condition_bytes,JNIEnv * env)895 void releaseFilterInput(jbyteArray schema_arr, jbyte* schema_bytes,
896 jbyteArray condition_arr, jbyte* condition_bytes, JNIEnv* env) {
897 env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
898 env->ReleaseByteArrayElements(condition_arr, condition_bytes, JNI_ABORT);
899 }
900
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildFilter(JNIEnv * env,jobject obj,jbyteArray schema_arr,jbyteArray condition_arr,jlong configuration_id)901 JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildFilter(
902 JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray condition_arr,
903 jlong configuration_id) {
904 jlong module_id = 0LL;
905 std::shared_ptr<Filter> filter;
906 std::shared_ptr<FilterHolder> holder;
907
908 types::Schema schema;
909 jsize schema_len = env->GetArrayLength(schema_arr);
910 jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0);
911
912 types::Condition condition;
913 jsize condition_len = env->GetArrayLength(condition_arr);
914 jbyte* condition_bytes = env->GetByteArrayElements(condition_arr, 0);
915
916 ConditionPtr condition_ptr;
917 SchemaPtr schema_ptr;
918 gandiva::Status status;
919
920 std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id);
921 std::stringstream ss;
922
923 if (config == nullptr) {
924 ss << "configuration is mandatory.";
925 releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
926 goto err_out;
927 }
928
929 if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) {
930 ss << "Unable to parse schema protobuf\n";
931 releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
932 goto err_out;
933 }
934
935 if (!ParseProtobuf(reinterpret_cast<uint8_t*>(condition_bytes), condition_len,
936 &condition)) {
937 ss << "Unable to parse condition protobuf\n";
938 releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
939 goto err_out;
940 }
941
942 // convert types::Schema to arrow::Schema
943 schema_ptr = ProtoTypeToSchema(schema);
944 if (schema_ptr == nullptr) {
945 ss << "Unable to construct arrow schema object from schema protobuf\n";
946 releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
947 goto err_out;
948 }
949
950 condition_ptr = ProtoTypeToCondition(condition);
951 if (condition_ptr == nullptr) {
952 ss << "Unable to construct condition object from condition protobuf\n";
953 releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
954 goto err_out;
955 }
956
957 // good to invoke the filter builder now
958 status = Filter::Make(schema_ptr, condition_ptr, config, &filter);
959 if (!status.ok()) {
960 ss << "Failed to make LLVM module due to " << status.message() << "\n";
961 releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
962 goto err_out;
963 }
964
965 // store the result in a map
966 holder = std::shared_ptr<FilterHolder>(new FilterHolder(schema_ptr, std::move(filter)));
967 module_id = filter_modules_.Insert(holder);
968 releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
969 return module_id;
970
971 err_out:
972 env->ThrowNew(gandiva_exception_, ss.str().c_str());
973 return module_id;
974 }
975
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateFilter(JNIEnv * env,jobject cls,jlong module_id,jint num_rows,jlongArray buf_addrs,jlongArray buf_sizes,jint jselection_vector_type,jlong out_buf_addr,jlong out_buf_size)976 JNIEXPORT jint JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateFilter(
977 JNIEnv* env, jobject cls, jlong module_id, jint num_rows, jlongArray buf_addrs,
978 jlongArray buf_sizes, jint jselection_vector_type, jlong out_buf_addr,
979 jlong out_buf_size) {
980 gandiva::Status status;
981 std::shared_ptr<FilterHolder> holder = filter_modules_.Lookup(module_id);
982 if (holder == nullptr) {
983 env->ThrowNew(gandiva_exception_, "Unknown module id\n");
984 return -1;
985 }
986
987 int in_bufs_len = env->GetArrayLength(buf_addrs);
988 if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
989 env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
990 return -1;
991 }
992
993 jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
994 jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
995 std::shared_ptr<gandiva::SelectionVector> selection_vector;
996
997 do {
998 std::shared_ptr<arrow::RecordBatch> in_batch;
999
1000 status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
1001 in_buf_sizes, in_bufs_len, &in_batch);
1002 if (!status.ok()) {
1003 break;
1004 }
1005
1006 auto selection_vector_type =
1007 static_cast<types::SelectionVectorType>(jselection_vector_type);
1008 auto out_buffer = std::make_shared<arrow::MutableBuffer>(
1009 reinterpret_cast<uint8_t*>(out_buf_addr), out_buf_size);
1010 switch (selection_vector_type) {
1011 case types::SV_INT16:
1012 status =
1013 gandiva::SelectionVector::MakeInt16(num_rows, out_buffer, &selection_vector);
1014 break;
1015 case types::SV_INT32:
1016 status =
1017 gandiva::SelectionVector::MakeInt32(num_rows, out_buffer, &selection_vector);
1018 break;
1019 default:
1020 status = gandiva::Status::Invalid("unknown selection vector type");
1021 }
1022 if (!status.ok()) {
1023 break;
1024 }
1025
1026 status = holder->filter()->Evaluate(*in_batch, selection_vector);
1027 } while (0);
1028
1029 env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
1030 env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
1031
1032 if (!status.ok()) {
1033 std::stringstream ss;
1034 ss << "Evaluate returned " << status.message() << "\n";
1035 env->ThrowNew(gandiva_exception_, status.message().c_str());
1036 return -1;
1037 } else {
1038 int64_t num_slots = selection_vector->GetNumSlots();
1039 // Check integer overflow
1040 if (num_slots > INT_MAX) {
1041 std::stringstream ss;
1042 ss << "The selection vector has " << num_slots
1043 << " slots, which is larger than the " << INT_MAX << " limit.\n";
1044 const std::string message = ss.str();
1045 env->ThrowNew(gandiva_exception_, message.c_str());
1046 return -1;
1047 }
1048 return static_cast<int>(num_slots);
1049 }
1050 }
1051
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeFilter(JNIEnv * env,jobject cls,jlong module_id)1052 JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeFilter(
1053 JNIEnv* env, jobject cls, jlong module_id) {
1054 filter_modules_.Erase(module_id);
1055 }
1056