1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
#include <google/protobuf/io/coded_stream.h>

#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <arrow/builder.h>
#include <arrow/record_batch.h>
#include <arrow/type.h>

#include "Types.pb.h"
#include "gandiva/configuration.h"
#include "gandiva/decimal_scalar.h"
#include "gandiva/filter.h"
#include "gandiva/jni/config_holder.h"
#include "gandiva/jni/env_helper.h"
#include "gandiva/jni/id_to_module_map.h"
#include "gandiva/jni/module_holder.h"
#include "gandiva/projector.h"
#include "gandiva/selection_vector.h"
#include "gandiva/tree_expr_builder.h"
#include "jni/org_apache_arrow_gandiva_evaluator_JniWrapper.h"
44 
45 using gandiva::ConditionPtr;
46 using gandiva::DataTypePtr;
47 using gandiva::ExpressionPtr;
48 using gandiva::ExpressionVector;
49 using gandiva::FieldPtr;
50 using gandiva::FieldVector;
51 using gandiva::Filter;
52 using gandiva::NodePtr;
53 using gandiva::NodeVector;
54 using gandiva::Projector;
55 using gandiva::SchemaPtr;
56 using gandiva::Status;
57 using gandiva::TreeExprBuilder;
58 
59 using gandiva::ArrayDataVector;
60 using gandiva::ConfigHolder;
61 using gandiva::Configuration;
62 using gandiva::ConfigurationBuilder;
63 using gandiva::FilterHolder;
64 using gandiva::ProjectorHolder;
65 
// forward declarations
// ProtoTypeToNode and the per-node-kind converters below (function, if, and,
// or, in) call each other recursively, so the dispatcher is declared up front.
NodePtr ProtoTypeToNode(const types::TreeNode& node);

// JNI version requested from the VM in JNI_OnLoad / JNI_OnUnload, and the
// value JNI_OnLoad reports back on success.
static jint JNI_VERSION = JNI_VERSION_1_6;

// extern refs - initialized for other modules.
// Global reference to the Java ConfigurationBuilder class; created in
// JNI_OnLoad and deleted in JNI_OnUnload.
jclass configuration_builder_class_;

// refs for self.
// Global class references and member IDs cached once in JNI_OnLoad and used by
// the JNI entry points in this file (exception throwing, and the Java-side
// VectorExpander callback used by JavaResizableBuffer::Resize).
static jclass gandiva_exception_;
static jclass vector_expander_class_;
static jclass vector_expander_ret_class_;
static jmethodID vector_expander_method_;
static jfieldID vector_expander_ret_address_;
static jfieldID vector_expander_ret_capacity_;

// module maps
// Registry of built projectors, keyed by the module id that buildProjector
// hands back to Java and evaluateProjector looks up. filter_modules_ is
// presumably the equivalent registry for filters (its users are not visible
// here).
gandiva::IdToModuleMap<std::shared_ptr<ProjectorHolder>> projector_modules_;
gandiva::IdToModuleMap<std::shared_ptr<FilterHolder>> filter_modules_;
85 
JNI_OnLoad(JavaVM * vm,void * reserved)86 jint JNI_OnLoad(JavaVM* vm, void* reserved) {
87   JNIEnv* env;
88   if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
89     return JNI_ERR;
90   }
91   jclass local_configuration_builder_class_ =
92       env->FindClass("org/apache/arrow/gandiva/evaluator/ConfigurationBuilder");
93   configuration_builder_class_ =
94       (jclass)env->NewGlobalRef(local_configuration_builder_class_);
95   env->DeleteLocalRef(local_configuration_builder_class_);
96 
97   jclass localExceptionClass =
98       env->FindClass("org/apache/arrow/gandiva/exceptions/GandivaException");
99   gandiva_exception_ = (jclass)env->NewGlobalRef(localExceptionClass);
100   env->ExceptionDescribe();
101   env->DeleteLocalRef(localExceptionClass);
102 
103   jclass local_expander_class =
104       env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander");
105   vector_expander_class_ = (jclass)env->NewGlobalRef(local_expander_class);
106   env->DeleteLocalRef(local_expander_class);
107 
108   vector_expander_method_ = env->GetMethodID(
109       vector_expander_class_, "expandOutputVectorAtIndex",
110       "(IJ)Lorg/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult;");
111 
112   jclass local_expander_ret_class =
113       env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult");
114   vector_expander_ret_class_ = (jclass)env->NewGlobalRef(local_expander_ret_class);
115   env->DeleteLocalRef(local_expander_ret_class);
116 
117   vector_expander_ret_address_ =
118       env->GetFieldID(vector_expander_ret_class_, "address", "J");
119   vector_expander_ret_capacity_ =
120       env->GetFieldID(vector_expander_ret_class_, "capacity", "J");
121   return JNI_VERSION;
122 }
123 
JNI_OnUnload(JavaVM * vm,void * reserved)124 void JNI_OnUnload(JavaVM* vm, void* reserved) {
125   JNIEnv* env;
126   vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION);
127   env->DeleteGlobalRef(configuration_builder_class_);
128   env->DeleteGlobalRef(gandiva_exception_);
129   env->DeleteGlobalRef(vector_expander_class_);
130   env->DeleteGlobalRef(vector_expander_ret_class_);
131 }
132 
ProtoTypeToTime32(const types::ExtGandivaType & ext_type)133 DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType& ext_type) {
134   switch (ext_type.timeunit()) {
135     case types::SEC:
136       return arrow::time32(arrow::TimeUnit::SECOND);
137     case types::MILLISEC:
138       return arrow::time32(arrow::TimeUnit::MILLI);
139     default:
140       std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time32\n";
141       return nullptr;
142   }
143 }
144 
ProtoTypeToTime64(const types::ExtGandivaType & ext_type)145 DataTypePtr ProtoTypeToTime64(const types::ExtGandivaType& ext_type) {
146   switch (ext_type.timeunit()) {
147     case types::MICROSEC:
148       return arrow::time64(arrow::TimeUnit::MICRO);
149     case types::NANOSEC:
150       return arrow::time64(arrow::TimeUnit::NANO);
151     default:
152       std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time64\n";
153       return nullptr;
154   }
155 }
156 
ProtoTypeToTimestamp(const types::ExtGandivaType & ext_type)157 DataTypePtr ProtoTypeToTimestamp(const types::ExtGandivaType& ext_type) {
158   switch (ext_type.timeunit()) {
159     case types::SEC:
160       return arrow::timestamp(arrow::TimeUnit::SECOND);
161     case types::MILLISEC:
162       return arrow::timestamp(arrow::TimeUnit::MILLI);
163     case types::MICROSEC:
164       return arrow::timestamp(arrow::TimeUnit::MICRO);
165     case types::NANOSEC:
166       return arrow::timestamp(arrow::TimeUnit::NANO);
167     default:
168       std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for timestamp\n";
169       return nullptr;
170   }
171 }
172 
ProtoTypeToInterval(const types::ExtGandivaType & ext_type)173 DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) {
174   switch (ext_type.intervaltype()) {
175     case types::YEAR_MONTH:
176       return arrow::month_interval();
177     case types::DAY_TIME:
178       return arrow::day_time_interval();
179     default:
180       std::cerr << "Unknown interval type: " << ext_type.intervaltype() << "\n";
181       return nullptr;
182   }
183 }
184 
ProtoTypeToDataType(const types::ExtGandivaType & ext_type)185 DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) {
186   switch (ext_type.type()) {
187     case types::NONE:
188       return arrow::null();
189     case types::BOOL:
190       return arrow::boolean();
191     case types::UINT8:
192       return arrow::uint8();
193     case types::INT8:
194       return arrow::int8();
195     case types::UINT16:
196       return arrow::uint16();
197     case types::INT16:
198       return arrow::int16();
199     case types::UINT32:
200       return arrow::uint32();
201     case types::INT32:
202       return arrow::int32();
203     case types::UINT64:
204       return arrow::uint64();
205     case types::INT64:
206       return arrow::int64();
207     case types::HALF_FLOAT:
208       return arrow::float16();
209     case types::FLOAT:
210       return arrow::float32();
211     case types::DOUBLE:
212       return arrow::float64();
213     case types::UTF8:
214       return arrow::utf8();
215     case types::BINARY:
216       return arrow::binary();
217     case types::DATE32:
218       return arrow::date32();
219     case types::DATE64:
220       return arrow::date64();
221     case types::DECIMAL:
222       // TODO: error handling
223       return arrow::decimal(ext_type.precision(), ext_type.scale());
224     case types::TIME32:
225       return ProtoTypeToTime32(ext_type);
226     case types::TIME64:
227       return ProtoTypeToTime64(ext_type);
228     case types::TIMESTAMP:
229       return ProtoTypeToTimestamp(ext_type);
230     case types::INTERVAL:
231       return ProtoTypeToInterval(ext_type);
232     case types::FIXED_SIZE_BINARY:
233     case types::LIST:
234     case types::STRUCT:
235     case types::UNION:
236     case types::DICTIONARY:
237     case types::MAP:
238       std::cerr << "Unhandled data type: " << ext_type.type() << "\n";
239       return nullptr;
240 
241     default:
242       std::cerr << "Unknown data type: " << ext_type.type() << "\n";
243       return nullptr;
244   }
245 }
246 
ProtoTypeToField(const types::Field & f)247 FieldPtr ProtoTypeToField(const types::Field& f) {
248   const std::string& name = f.name();
249   DataTypePtr type = ProtoTypeToDataType(f.type());
250   bool nullable = true;
251   if (f.has_nullable()) {
252     nullable = f.nullable();
253   }
254 
255   return field(name, type, nullable);
256 }
257 
ProtoTypeToFieldNode(const types::FieldNode & node)258 NodePtr ProtoTypeToFieldNode(const types::FieldNode& node) {
259   FieldPtr field_ptr = ProtoTypeToField(node.field());
260   if (field_ptr == nullptr) {
261     std::cerr << "Unable to create field node from protobuf\n";
262     return nullptr;
263   }
264 
265   return TreeExprBuilder::MakeField(field_ptr);
266 }
267 
ProtoTypeToFnNode(const types::FunctionNode & node)268 NodePtr ProtoTypeToFnNode(const types::FunctionNode& node) {
269   const std::string& name = node.functionname();
270   NodeVector children;
271 
272   for (int i = 0; i < node.inargs_size(); i++) {
273     const types::TreeNode& arg = node.inargs(i);
274 
275     NodePtr n = ProtoTypeToNode(arg);
276     if (n == nullptr) {
277       std::cerr << "Unable to create argument for function: " << name << "\n";
278       return nullptr;
279     }
280 
281     children.push_back(n);
282   }
283 
284   DataTypePtr return_type = ProtoTypeToDataType(node.returntype());
285   if (return_type == nullptr) {
286     std::cerr << "Unknown return type for function: " << name << "\n";
287     return nullptr;
288   }
289 
290   return TreeExprBuilder::MakeFunction(name, children, return_type);
291 }
292 
ProtoTypeToIfNode(const types::IfNode & node)293 NodePtr ProtoTypeToIfNode(const types::IfNode& node) {
294   NodePtr cond = ProtoTypeToNode(node.cond());
295   if (cond == nullptr) {
296     std::cerr << "Unable to create cond node for if node\n";
297     return nullptr;
298   }
299 
300   NodePtr then_node = ProtoTypeToNode(node.thennode());
301   if (then_node == nullptr) {
302     std::cerr << "Unable to create then node for if node\n";
303     return nullptr;
304   }
305 
306   NodePtr else_node = ProtoTypeToNode(node.elsenode());
307   if (else_node == nullptr) {
308     std::cerr << "Unable to create else node for if node\n";
309     return nullptr;
310   }
311 
312   DataTypePtr return_type = ProtoTypeToDataType(node.returntype());
313   if (return_type == nullptr) {
314     std::cerr << "Unknown return type for if node\n";
315     return nullptr;
316   }
317 
318   return TreeExprBuilder::MakeIf(cond, then_node, else_node, return_type);
319 }
320 
ProtoTypeToAndNode(const types::AndNode & node)321 NodePtr ProtoTypeToAndNode(const types::AndNode& node) {
322   NodeVector children;
323 
324   for (int i = 0; i < node.args_size(); i++) {
325     const types::TreeNode& arg = node.args(i);
326 
327     NodePtr n = ProtoTypeToNode(arg);
328     if (n == nullptr) {
329       std::cerr << "Unable to create argument for boolean and\n";
330       return nullptr;
331     }
332     children.push_back(n);
333   }
334   return TreeExprBuilder::MakeAnd(children);
335 }
336 
ProtoTypeToOrNode(const types::OrNode & node)337 NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
338   NodeVector children;
339 
340   for (int i = 0; i < node.args_size(); i++) {
341     const types::TreeNode& arg = node.args(i);
342 
343     NodePtr n = ProtoTypeToNode(arg);
344     if (n == nullptr) {
345       std::cerr << "Unable to create argument for boolean or\n";
346       return nullptr;
347     }
348     children.push_back(n);
349   }
350   return TreeExprBuilder::MakeOr(children);
351 }
352 
ProtoTypeToInNode(const types::InNode & node)353 NodePtr ProtoTypeToInNode(const types::InNode& node) {
354   NodePtr field = ProtoTypeToNode(node.node());
355 
356   if (node.has_intvalues()) {
357     std::unordered_set<int32_t> int_values;
358     for (int i = 0; i < node.intvalues().intvalues_size(); i++) {
359       int_values.insert(node.intvalues().intvalues(i).value());
360     }
361     return TreeExprBuilder::MakeInExpressionInt32(field, int_values);
362   }
363 
364   if (node.has_longvalues()) {
365     std::unordered_set<int64_t> long_values;
366     for (int i = 0; i < node.longvalues().longvalues_size(); i++) {
367       long_values.insert(node.longvalues().longvalues(i).value());
368     }
369     return TreeExprBuilder::MakeInExpressionInt64(field, long_values);
370   }
371 
372   if (node.has_decimalvalues()) {
373     std::unordered_set<gandiva::DecimalScalar128> decimal_values;
374     for (int i = 0; i < node.decimalvalues().decimalvalues_size(); i++) {
375       decimal_values.insert(
376           gandiva::DecimalScalar128(node.decimalvalues().decimalvalues(i).value(),
377                                     node.decimalvalues().decimalvalues(i).precision(),
378                                     node.decimalvalues().decimalvalues(i).scale()));
379     }
380     return TreeExprBuilder::MakeInExpressionDecimal(field, decimal_values);
381   }
382 
383   if (node.has_floatvalues()) {
384     std::unordered_set<float> float_values;
385     for (int i = 0; i < node.floatvalues().floatvalues_size(); i++) {
386       float_values.insert(node.floatvalues().floatvalues(i).value());
387     }
388     return TreeExprBuilder::MakeInExpressionFloat(field, float_values);
389   }
390 
391   if (node.has_doublevalues()) {
392     std::unordered_set<double> double_values;
393     for (int i = 0; i < node.doublevalues().doublevalues_size(); i++) {
394       double_values.insert(node.doublevalues().doublevalues(i).value());
395     }
396     return TreeExprBuilder::MakeInExpressionDouble(field, double_values);
397   }
398 
399   if (node.has_stringvalues()) {
400     std::unordered_set<std::string> stringvalues;
401     for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
402       stringvalues.insert(node.stringvalues().stringvalues(i).value());
403     }
404     return TreeExprBuilder::MakeInExpressionString(field, stringvalues);
405   }
406 
407   if (node.has_binaryvalues()) {
408     std::unordered_set<std::string> stringvalues;
409     for (int i = 0; i < node.binaryvalues().binaryvalues_size(); i++) {
410       stringvalues.insert(node.binaryvalues().binaryvalues(i).value());
411     }
412     return TreeExprBuilder::MakeInExpressionBinary(field, stringvalues);
413   }
414   // not supported yet.
415   std::cerr << "Unknown constant type for in expression.\n";
416   return nullptr;
417 }
418 
ProtoTypeToNullNode(const types::NullNode & node)419 NodePtr ProtoTypeToNullNode(const types::NullNode& node) {
420   DataTypePtr data_type = ProtoTypeToDataType(node.type());
421   if (data_type == nullptr) {
422     std::cerr << "Unknown type " << data_type->ToString() << " for null node\n";
423     return nullptr;
424   }
425 
426   return TreeExprBuilder::MakeNull(data_type);
427 }
428 
ProtoTypeToNode(const types::TreeNode & node)429 NodePtr ProtoTypeToNode(const types::TreeNode& node) {
430   if (node.has_fieldnode()) {
431     return ProtoTypeToFieldNode(node.fieldnode());
432   }
433 
434   if (node.has_fnnode()) {
435     return ProtoTypeToFnNode(node.fnnode());
436   }
437 
438   if (node.has_ifnode()) {
439     return ProtoTypeToIfNode(node.ifnode());
440   }
441 
442   if (node.has_andnode()) {
443     return ProtoTypeToAndNode(node.andnode());
444   }
445 
446   if (node.has_ornode()) {
447     return ProtoTypeToOrNode(node.ornode());
448   }
449 
450   if (node.has_innode()) {
451     return ProtoTypeToInNode(node.innode());
452   }
453 
454   if (node.has_nullnode()) {
455     return ProtoTypeToNullNode(node.nullnode());
456   }
457 
458   if (node.has_intnode()) {
459     return TreeExprBuilder::MakeLiteral(node.intnode().value());
460   }
461 
462   if (node.has_floatnode()) {
463     return TreeExprBuilder::MakeLiteral(node.floatnode().value());
464   }
465 
466   if (node.has_longnode()) {
467     return TreeExprBuilder::MakeLiteral(node.longnode().value());
468   }
469 
470   if (node.has_booleannode()) {
471     return TreeExprBuilder::MakeLiteral(node.booleannode().value());
472   }
473 
474   if (node.has_doublenode()) {
475     return TreeExprBuilder::MakeLiteral(node.doublenode().value());
476   }
477 
478   if (node.has_stringnode()) {
479     return TreeExprBuilder::MakeStringLiteral(node.stringnode().value());
480   }
481 
482   if (node.has_binarynode()) {
483     return TreeExprBuilder::MakeBinaryLiteral(node.binarynode().value());
484   }
485 
486   if (node.has_decimalnode()) {
487     std::string value = node.decimalnode().value();
488     gandiva::DecimalScalar128 literal(value, node.decimalnode().precision(),
489                                       node.decimalnode().scale());
490     return TreeExprBuilder::MakeDecimalLiteral(literal);
491   }
492   std::cerr << "Unknown node type in protobuf\n";
493   return nullptr;
494 }
495 
ProtoTypeToExpression(const types::ExpressionRoot & root)496 ExpressionPtr ProtoTypeToExpression(const types::ExpressionRoot& root) {
497   NodePtr root_node = ProtoTypeToNode(root.root());
498   if (root_node == nullptr) {
499     std::cerr << "Unable to create expression node from expression protobuf\n";
500     return nullptr;
501   }
502 
503   FieldPtr field = ProtoTypeToField(root.resulttype());
504   if (field == nullptr) {
505     std::cerr << "Unable to extra return field from expression protobuf\n";
506     return nullptr;
507   }
508 
509   return TreeExprBuilder::MakeExpression(root_node, field);
510 }
511 
ProtoTypeToCondition(const types::Condition & condition)512 ConditionPtr ProtoTypeToCondition(const types::Condition& condition) {
513   NodePtr root_node = ProtoTypeToNode(condition.root());
514   if (root_node == nullptr) {
515     return nullptr;
516   }
517 
518   return TreeExprBuilder::MakeCondition(root_node);
519 }
520 
ProtoTypeToSchema(const types::Schema & schema)521 SchemaPtr ProtoTypeToSchema(const types::Schema& schema) {
522   std::vector<FieldPtr> fields;
523 
524   for (int i = 0; i < schema.columns_size(); i++) {
525     FieldPtr field = ProtoTypeToField(schema.columns(i));
526     if (field == nullptr) {
527       std::cerr << "Unable to extract arrow field from schema\n";
528       return nullptr;
529     }
530 
531     fields.push_back(field);
532   }
533 
534   return arrow::schema(fields);
535 }
536 
537 // Common for both projector and filters.
538 
ParseProtobuf(uint8_t * buf,int bufLen,google::protobuf::Message * msg)539 bool ParseProtobuf(uint8_t* buf, int bufLen, google::protobuf::Message* msg) {
540   google::protobuf::io::CodedInputStream cis(buf, bufLen);
541   cis.SetRecursionLimit(1000);
542   return msg->ParseFromCodedStream(&cis);
543 }
544 
make_record_batch_with_buf_addrs(SchemaPtr schema,int num_rows,jlong * in_buf_addrs,jlong * in_buf_sizes,int in_bufs_len,std::shared_ptr<arrow::RecordBatch> * batch)545 Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows,
546                                         jlong* in_buf_addrs, jlong* in_buf_sizes,
547                                         int in_bufs_len,
548                                         std::shared_ptr<arrow::RecordBatch>* batch) {
549   std::vector<std::shared_ptr<arrow::ArrayData>> columns;
550   auto num_fields = schema->num_fields();
551   int buf_idx = 0;
552   int sz_idx = 0;
553 
554   for (int i = 0; i < num_fields; i++) {
555     auto field = schema->field(i);
556     std::vector<std::shared_ptr<arrow::Buffer>> buffers;
557 
558     if (buf_idx >= in_bufs_len) {
559       return Status::Invalid("insufficient number of in_buf_addrs");
560     }
561     jlong validity_addr = in_buf_addrs[buf_idx++];
562     jlong validity_size = in_buf_sizes[sz_idx++];
563     auto validity = std::shared_ptr<arrow::Buffer>(
564         new arrow::Buffer(reinterpret_cast<uint8_t*>(validity_addr), validity_size));
565     buffers.push_back(validity);
566 
567     if (buf_idx >= in_bufs_len) {
568       return Status::Invalid("insufficient number of in_buf_addrs");
569     }
570     jlong value_addr = in_buf_addrs[buf_idx++];
571     jlong value_size = in_buf_sizes[sz_idx++];
572     auto data = std::shared_ptr<arrow::Buffer>(
573         new arrow::Buffer(reinterpret_cast<uint8_t*>(value_addr), value_size));
574     buffers.push_back(data);
575 
576     if (arrow::is_binary_like(field->type()->id())) {
577       if (buf_idx >= in_bufs_len) {
578         return Status::Invalid("insufficient number of in_buf_addrs");
579       }
580 
581       // add offsets buffer for variable-len fields.
582       jlong offsets_addr = in_buf_addrs[buf_idx++];
583       jlong offsets_size = in_buf_sizes[sz_idx++];
584       auto offsets = std::shared_ptr<arrow::Buffer>(
585           new arrow::Buffer(reinterpret_cast<uint8_t*>(offsets_addr), offsets_size));
586       buffers.push_back(offsets);
587     }
588 
589     auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers));
590     columns.push_back(array_data);
591   }
592   *batch = arrow::RecordBatch::Make(schema, num_rows, columns);
593   return Status::OK();
594 }
595 
596 // projector related functions.
releaseProjectorInput(jbyteArray schema_arr,jbyte * schema_bytes,jbyteArray exprs_arr,jbyte * exprs_bytes,JNIEnv * env)597 void releaseProjectorInput(jbyteArray schema_arr, jbyte* schema_bytes,
598                            jbyteArray exprs_arr, jbyte* exprs_bytes, JNIEnv* env) {
599   env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
600   env->ReleaseByteArrayElements(exprs_arr, exprs_bytes, JNI_ABORT);
601 }
602 
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildProjector(JNIEnv * env,jobject obj,jbyteArray schema_arr,jbyteArray exprs_arr,jint selection_vector_type,jlong configuration_id)603 JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildProjector(
604     JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray exprs_arr,
605     jint selection_vector_type, jlong configuration_id) {
606   jlong module_id = 0LL;
607   std::shared_ptr<Projector> projector;
608   std::shared_ptr<ProjectorHolder> holder;
609 
610   types::Schema schema;
611   jsize schema_len = env->GetArrayLength(schema_arr);
612   jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0);
613 
614   types::ExpressionList exprs;
615   jsize exprs_len = env->GetArrayLength(exprs_arr);
616   jbyte* exprs_bytes = env->GetByteArrayElements(exprs_arr, 0);
617 
618   ExpressionVector expr_vector;
619   SchemaPtr schema_ptr;
620   FieldVector ret_types;
621   gandiva::Status status;
622   auto mode = gandiva::SelectionVector::MODE_NONE;
623 
624   std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id);
625   std::stringstream ss;
626 
627   if (config == nullptr) {
628     ss << "configuration is mandatory.";
629     releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
630     goto err_out;
631   }
632 
633   if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) {
634     ss << "Unable to parse schema protobuf\n";
635     releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
636     goto err_out;
637   }
638 
639   if (!ParseProtobuf(reinterpret_cast<uint8_t*>(exprs_bytes), exprs_len, &exprs)) {
640     releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
641     ss << "Unable to parse expressions protobuf\n";
642     goto err_out;
643   }
644 
645   // convert types::Schema to arrow::Schema
646   schema_ptr = ProtoTypeToSchema(schema);
647   if (schema_ptr == nullptr) {
648     ss << "Unable to construct arrow schema object from schema protobuf\n";
649     releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
650     goto err_out;
651   }
652 
653   // create Expression out of the list of exprs
654   for (int i = 0; i < exprs.exprs_size(); i++) {
655     ExpressionPtr root = ProtoTypeToExpression(exprs.exprs(i));
656 
657     if (root == nullptr) {
658       ss << "Unable to construct expression object from expression protobuf\n";
659       releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
660       goto err_out;
661     }
662 
663     expr_vector.push_back(root);
664     ret_types.push_back(root->result());
665   }
666 
667   switch (selection_vector_type) {
668     case types::SV_NONE:
669       mode = gandiva::SelectionVector::MODE_NONE;
670       break;
671     case types::SV_INT16:
672       mode = gandiva::SelectionVector::MODE_UINT16;
673       break;
674     case types::SV_INT32:
675       mode = gandiva::SelectionVector::MODE_UINT32;
676       break;
677   }
678   // good to invoke the evaluator now
679   status = Projector::Make(schema_ptr, expr_vector, mode, config, &projector);
680 
681   if (!status.ok()) {
682     ss << "Failed to make LLVM module due to " << status.message() << "\n";
683     releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
684     goto err_out;
685   }
686 
687   // store the result in a map
688   holder = std::shared_ptr<ProjectorHolder>(
689       new ProjectorHolder(schema_ptr, ret_types, std::move(projector)));
690   module_id = projector_modules_.Insert(holder);
691   releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
692   return module_id;
693 
694 err_out:
695   env->ThrowNew(gandiva_exception_, ss.str().c_str());
696   return module_id;
697 }
698 
699 ///
700 /// \brief Resizable buffer which resizes by doing a callback into java.
701 ///
702 class JavaResizableBuffer : public arrow::ResizableBuffer {
703  public:
JavaResizableBuffer(JNIEnv * env,jobject jexpander,int32_t vector_idx,uint8_t * buffer,int32_t len)704   JavaResizableBuffer(JNIEnv* env, jobject jexpander, int32_t vector_idx, uint8_t* buffer,
705                       int32_t len)
706       : ResizableBuffer(buffer, len),
707         env_(env),
708         jexpander_(jexpander),
709         vector_idx_(vector_idx) {
710     size_ = 0;
711   }
712 
713   Status Resize(const int64_t new_size, bool shrink_to_fit) override;
714 
Reserve(const int64_t new_capacity)715   Status Reserve(const int64_t new_capacity) override {
716     return Status::NotImplemented("reserve not implemented");
717   }
718 
719  private:
720   JNIEnv* env_;
721   jobject jexpander_;
722   int32_t vector_idx_;
723 };
724 
Resize(const int64_t new_size,bool shrink_to_fit)725 Status JavaResizableBuffer::Resize(const int64_t new_size, bool shrink_to_fit) {
726   if (shrink_to_fit == true) {
727     return Status::NotImplemented("shrink not implemented");
728   }
729 
730   if (ARROW_PREDICT_TRUE(new_size < capacity())) {
731     // no need to expand.
732     size_ = new_size;
733     return Status::OK();
734   }
735 
736   // callback into java to expand the buffer
737   jobject ret =
738       env_->CallObjectMethod(jexpander_, vector_expander_method_, vector_idx_, new_size);
739   if (env_->ExceptionCheck()) {
740     env_->ExceptionDescribe();
741     env_->ExceptionClear();
742     return Status::OutOfMemory("buffer expand failed in java");
743   }
744 
745   jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_);
746   jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_);
747   DCHECK_GE(ret_capacity, new_size);
748 
749   data_ = reinterpret_cast<uint8_t*>(ret_address);
750   size_ = new_size;
751   capacity_ = ret_capacity;
752   return Status::OK();
753 }
754 
// Bounds check for the output-buffer arrays. If `idx` has reached `len`, it
// assigns an Invalid status to a variable named `status`, which must be in
// scope at the use site. It then `break`s out of the innermost enclosing loop
// or switch. It cannot be wrapped in do { } while (0): that loop would capture
// the break.
#define CHECK_OUT_BUFFER_IDX_AND_BREAK(idx, len)                               \
  if (idx >= len) {                                                            \
    status = gandiva::Status::Invalid("insufficient number of out_buf_addrs"); \
    break;                                                                     \
  }
760 
761 JNIEXPORT void JNICALL
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector(JNIEnv * env,jobject object,jobject jexpander,jlong module_id,jint num_rows,jlongArray buf_addrs,jlongArray buf_sizes,jint sel_vec_type,jint sel_vec_rows,jlong sel_vec_addr,jlong sel_vec_size,jlongArray out_buf_addrs,jlongArray out_buf_sizes)762 Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector(
763     JNIEnv* env, jobject object, jobject jexpander, jlong module_id, jint num_rows,
764     jlongArray buf_addrs, jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows,
765     jlong sel_vec_addr, jlong sel_vec_size, jlongArray out_buf_addrs,
766     jlongArray out_buf_sizes) {
767   Status status;
768   std::shared_ptr<ProjectorHolder> holder = projector_modules_.Lookup(module_id);
769   if (holder == nullptr) {
770     std::stringstream ss;
771     ss << "Unknown module id " << module_id;
772     env->ThrowNew(gandiva_exception_, ss.str().c_str());
773     return;
774   }
775 
776   int in_bufs_len = env->GetArrayLength(buf_addrs);
777   if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
778     env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
779     return;
780   }
781 
782   int out_bufs_len = env->GetArrayLength(out_buf_addrs);
783   if (out_bufs_len != env->GetArrayLength(out_buf_sizes)) {
784     env->ThrowNew(gandiva_exception_,
785                   "mismatch in arraylen of out_buf_addrs and out_buf_sizes");
786     return;
787   }
788 
789   jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
790   jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
791 
792   jlong* out_bufs = env->GetLongArrayElements(out_buf_addrs, 0);
793   jlong* out_sizes = env->GetLongArrayElements(out_buf_sizes, 0);
794 
795   do {
796     std::shared_ptr<arrow::RecordBatch> in_batch;
797     status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
798                                               in_buf_sizes, in_bufs_len, &in_batch);
799     if (!status.ok()) {
800       break;
801     }
802 
803     std::shared_ptr<gandiva::SelectionVector> selection_vector;
804     auto selection_buffer = std::make_shared<arrow::Buffer>(
805         reinterpret_cast<uint8_t*>(sel_vec_addr), sel_vec_size);
806     int output_row_count = 0;
807     switch (sel_vec_type) {
808       case types::SV_NONE: {
809         output_row_count = num_rows;
810         break;
811       }
812       case types::SV_INT16: {
813         status = gandiva::SelectionVector::MakeImmutableInt16(
814             sel_vec_rows, selection_buffer, &selection_vector);
815         output_row_count = sel_vec_rows;
816         break;
817       }
818       case types::SV_INT32: {
819         status = gandiva::SelectionVector::MakeImmutableInt32(
820             sel_vec_rows, selection_buffer, &selection_vector);
821         output_row_count = sel_vec_rows;
822         break;
823       }
824     }
825     if (!status.ok()) {
826       break;
827     }
828 
829     auto ret_types = holder->rettypes();
830     ArrayDataVector output;
831     int buf_idx = 0;
832     int sz_idx = 0;
833     int output_vector_idx = 0;
834     for (FieldPtr field : ret_types) {
835       std::vector<std::shared_ptr<arrow::Buffer>> buffers;
836 
837       CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
838       uint8_t* validity_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
839       jlong bitmap_sz = out_sizes[sz_idx++];
840       buffers.push_back(std::make_shared<arrow::MutableBuffer>(validity_buf, bitmap_sz));
841 
842       if (arrow::is_binary_like(field->type()->id())) {
843         CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
844         uint8_t* offsets_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
845         jlong offsets_sz = out_sizes[sz_idx++];
846         buffers.push_back(
847             std::make_shared<arrow::MutableBuffer>(offsets_buf, offsets_sz));
848       }
849 
850       CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
851       uint8_t* value_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
852       jlong data_sz = out_sizes[sz_idx++];
853       if (arrow::is_binary_like(field->type()->id())) {
854         if (jexpander == nullptr) {
855           status = Status::Invalid(
856               "expression has variable len output columns, but the expander object is "
857               "null");
858           break;
859         }
860         buffers.push_back(std::make_shared<JavaResizableBuffer>(
861             env, jexpander, output_vector_idx, value_buf, data_sz));
862       } else {
863         buffers.push_back(std::make_shared<arrow::MutableBuffer>(value_buf, data_sz));
864       }
865 
866       auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers);
867       output.push_back(array_data);
868       ++output_vector_idx;
869     }
870     if (!status.ok()) {
871       break;
872     }
873     status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output);
874   } while (0);
875 
876   env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
877   env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
878   env->ReleaseLongArrayElements(out_buf_addrs, out_bufs, JNI_ABORT);
879   env->ReleaseLongArrayElements(out_buf_sizes, out_sizes, JNI_ABORT);
880 
881   if (!status.ok()) {
882     std::stringstream ss;
883     ss << "Evaluate returned " << status.message() << "\n";
884     env->ThrowNew(gandiva_exception_, status.message().c_str());
885     return;
886   }
887 }
888 
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeProjector(JNIEnv * env,jobject cls,jlong module_id)889 JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeProjector(
890     JNIEnv* env, jobject cls, jlong module_id) {
891   projector_modules_.Erase(module_id);
892 }
893 
894 // filter related functions.
releaseFilterInput(jbyteArray schema_arr,jbyte * schema_bytes,jbyteArray condition_arr,jbyte * condition_bytes,JNIEnv * env)895 void releaseFilterInput(jbyteArray schema_arr, jbyte* schema_bytes,
896                         jbyteArray condition_arr, jbyte* condition_bytes, JNIEnv* env) {
897   env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
898   env->ReleaseByteArrayElements(condition_arr, condition_bytes, JNI_ABORT);
899 }
900 
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildFilter(JNIEnv * env,jobject obj,jbyteArray schema_arr,jbyteArray condition_arr,jlong configuration_id)901 JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildFilter(
902     JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray condition_arr,
903     jlong configuration_id) {
904   jlong module_id = 0LL;
905   std::shared_ptr<Filter> filter;
906   std::shared_ptr<FilterHolder> holder;
907 
908   types::Schema schema;
909   jsize schema_len = env->GetArrayLength(schema_arr);
910   jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0);
911 
912   types::Condition condition;
913   jsize condition_len = env->GetArrayLength(condition_arr);
914   jbyte* condition_bytes = env->GetByteArrayElements(condition_arr, 0);
915 
916   ConditionPtr condition_ptr;
917   SchemaPtr schema_ptr;
918   gandiva::Status status;
919 
920   std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id);
921   std::stringstream ss;
922 
923   if (config == nullptr) {
924     ss << "configuration is mandatory.";
925     releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
926     goto err_out;
927   }
928 
929   if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) {
930     ss << "Unable to parse schema protobuf\n";
931     releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
932     goto err_out;
933   }
934 
935   if (!ParseProtobuf(reinterpret_cast<uint8_t*>(condition_bytes), condition_len,
936                      &condition)) {
937     ss << "Unable to parse condition protobuf\n";
938     releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
939     goto err_out;
940   }
941 
942   // convert types::Schema to arrow::Schema
943   schema_ptr = ProtoTypeToSchema(schema);
944   if (schema_ptr == nullptr) {
945     ss << "Unable to construct arrow schema object from schema protobuf\n";
946     releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
947     goto err_out;
948   }
949 
950   condition_ptr = ProtoTypeToCondition(condition);
951   if (condition_ptr == nullptr) {
952     ss << "Unable to construct condition object from condition protobuf\n";
953     releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
954     goto err_out;
955   }
956 
957   // good to invoke the filter builder now
958   status = Filter::Make(schema_ptr, condition_ptr, config, &filter);
959   if (!status.ok()) {
960     ss << "Failed to make LLVM module due to " << status.message() << "\n";
961     releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
962     goto err_out;
963   }
964 
965   // store the result in a map
966   holder = std::shared_ptr<FilterHolder>(new FilterHolder(schema_ptr, std::move(filter)));
967   module_id = filter_modules_.Insert(holder);
968   releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
969   return module_id;
970 
971 err_out:
972   env->ThrowNew(gandiva_exception_, ss.str().c_str());
973   return module_id;
974 }
975 
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateFilter(JNIEnv * env,jobject cls,jlong module_id,jint num_rows,jlongArray buf_addrs,jlongArray buf_sizes,jint jselection_vector_type,jlong out_buf_addr,jlong out_buf_size)976 JNIEXPORT jint JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateFilter(
977     JNIEnv* env, jobject cls, jlong module_id, jint num_rows, jlongArray buf_addrs,
978     jlongArray buf_sizes, jint jselection_vector_type, jlong out_buf_addr,
979     jlong out_buf_size) {
980   gandiva::Status status;
981   std::shared_ptr<FilterHolder> holder = filter_modules_.Lookup(module_id);
982   if (holder == nullptr) {
983     env->ThrowNew(gandiva_exception_, "Unknown module id\n");
984     return -1;
985   }
986 
987   int in_bufs_len = env->GetArrayLength(buf_addrs);
988   if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
989     env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
990     return -1;
991   }
992 
993   jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
994   jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
995   std::shared_ptr<gandiva::SelectionVector> selection_vector;
996 
997   do {
998     std::shared_ptr<arrow::RecordBatch> in_batch;
999 
1000     status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
1001                                               in_buf_sizes, in_bufs_len, &in_batch);
1002     if (!status.ok()) {
1003       break;
1004     }
1005 
1006     auto selection_vector_type =
1007         static_cast<types::SelectionVectorType>(jselection_vector_type);
1008     auto out_buffer = std::make_shared<arrow::MutableBuffer>(
1009         reinterpret_cast<uint8_t*>(out_buf_addr), out_buf_size);
1010     switch (selection_vector_type) {
1011       case types::SV_INT16:
1012         status =
1013             gandiva::SelectionVector::MakeInt16(num_rows, out_buffer, &selection_vector);
1014         break;
1015       case types::SV_INT32:
1016         status =
1017             gandiva::SelectionVector::MakeInt32(num_rows, out_buffer, &selection_vector);
1018         break;
1019       default:
1020         status = gandiva::Status::Invalid("unknown selection vector type");
1021     }
1022     if (!status.ok()) {
1023       break;
1024     }
1025 
1026     status = holder->filter()->Evaluate(*in_batch, selection_vector);
1027   } while (0);
1028 
1029   env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
1030   env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
1031 
1032   if (!status.ok()) {
1033     std::stringstream ss;
1034     ss << "Evaluate returned " << status.message() << "\n";
1035     env->ThrowNew(gandiva_exception_, status.message().c_str());
1036     return -1;
1037   } else {
1038     int64_t num_slots = selection_vector->GetNumSlots();
1039     // Check integer overflow
1040     if (num_slots > INT_MAX) {
1041       std::stringstream ss;
1042       ss << "The selection vector has " << num_slots
1043          << " slots, which is larger than the " << INT_MAX << " limit.\n";
1044       const std::string message = ss.str();
1045       env->ThrowNew(gandiva_exception_, message.c_str());
1046       return -1;
1047     }
1048     return static_cast<int>(num_slots);
1049   }
1050 }
1051 
Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeFilter(JNIEnv * env,jobject cls,jlong module_id)1052 JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeFilter(
1053     JNIEnv* env, jobject cls, jlong module_id) {
1054   filter_modules_.Erase(module_id);
1055 }
1056