1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "gandiva/llvm_generator.h"
19 
20 #include <fstream>
21 #include <iostream>
22 #include <sstream>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "gandiva/bitmap_accumulator.h"
28 #include "gandiva/decimal_ir.h"
29 #include "gandiva/dex.h"
30 #include "gandiva/expr_decomposer.h"
31 #include "gandiva/expression.h"
32 #include "gandiva/lvalue.h"
33 
34 namespace gandiva {
35 
// Emit an IR trace through AddTrace() only when IR tracing is enabled.
// The do { } while (0) wrapper makes `ADD_TRACE(...);` a single statement, so it is
// safe to use as the body of an unbraced if/else.
#define ADD_TRACE(...)       \
  do {                       \
    if (enable_ir_traces_) { \
      AddTrace(__VA_ARGS__); \
    }                        \
  } while (0)
40 
LLVMGenerator()41 LLVMGenerator::LLVMGenerator() : enable_ir_traces_(false) {}
42 
Make(std::shared_ptr<Configuration> config,std::unique_ptr<LLVMGenerator> * llvm_generator)43 Status LLVMGenerator::Make(std::shared_ptr<Configuration> config,
44                            std::unique_ptr<LLVMGenerator>* llvm_generator) {
45   std::unique_ptr<LLVMGenerator> llvmgen_obj(new LLVMGenerator());
46 
47   ARROW_RETURN_NOT_OK(Engine::Make(config, &(llvmgen_obj->engine_)));
48   *llvm_generator = std::move(llvmgen_obj);
49 
50   return Status::OK();
51 }
52 
Add(const ExpressionPtr expr,const FieldDescriptorPtr output)53 Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr output) {
54   int idx = static_cast<int>(compiled_exprs_.size());
55   // decompose the expression to separate out value and validities.
56   ExprDecomposer decomposer(function_registry_, annotator_);
57   ValueValidityPairPtr value_validity;
58   ARROW_RETURN_NOT_OK(decomposer.Decompose(*expr->root(), &value_validity));
59   // Generate the IR function for the decomposed expression.
60   std::unique_ptr<CompiledExpr> compiled_expr(new CompiledExpr(value_validity, output));
61   llvm::Function* ir_function = nullptr;
62   ARROW_RETURN_NOT_OK(CodeGenExprValue(value_validity->value_expr(),
63                                        annotator_.buffer_count(), output, idx,
64                                        &ir_function, selection_vector_mode_));
65   compiled_expr->SetIRFunction(selection_vector_mode_, ir_function);
66 
67   compiled_exprs_.push_back(std::move(compiled_expr));
68   return Status::OK();
69 }
70 
71 /// Build and optimise module for projection expression.
Build(const ExpressionVector & exprs,SelectionVector::Mode mode)72 Status LLVMGenerator::Build(const ExpressionVector& exprs, SelectionVector::Mode mode) {
73   selection_vector_mode_ = mode;
74   for (auto& expr : exprs) {
75     auto output = annotator_.AddOutputFieldDescriptor(expr->result());
76     ARROW_RETURN_NOT_OK(Add(expr, output));
77   }
78 
79   // Compile and inject into the process' memory the generated function.
80   ARROW_RETURN_NOT_OK(engine_->FinalizeModule());
81 
82   // setup the jit functions for each expression.
83   for (auto& compiled_expr : compiled_exprs_) {
84     auto ir_fn = compiled_expr->GetIRFunction(mode);
85     auto jit_fn = reinterpret_cast<EvalFunc>(engine_->CompiledFunction(ir_fn));
86     compiled_expr->SetJITFunction(selection_vector_mode_, jit_fn);
87   }
88 
89   return Status::OK();
90 }
91 
92 /// Execute the compiled module against the provided vectors.
Execute(const arrow::RecordBatch & record_batch,const ArrayDataVector & output_vector)93 Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
94                               const ArrayDataVector& output_vector) {
95   return Execute(record_batch, nullptr, output_vector);
96 }
97 
98 /// Execute the compiled module against the provided vectors based on the type of
99 /// selection vector.
Execute(const arrow::RecordBatch & record_batch,const SelectionVector * selection_vector,const ArrayDataVector & output_vector)100 Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
101                               const SelectionVector* selection_vector,
102                               const ArrayDataVector& output_vector) {
103   DCHECK_GT(record_batch.num_rows(), 0);
104 
105   auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector);
106   DCHECK_GT(eval_batch->GetNumBuffers(), 0);
107 
108   auto mode = SelectionVector::MODE_NONE;
109   if (selection_vector != nullptr) {
110     mode = selection_vector->GetMode();
111   }
112   if (mode != selection_vector_mode_) {
113     return Status::Invalid("llvm expression built for selection vector mode ",
114                            selection_vector_mode_, " received vector with mode ", mode);
115   }
116 
117   for (auto& compiled_expr : compiled_exprs_) {
118     // generate data/offset vectors.
119     const uint8_t* selection_buffer = nullptr;
120     auto num_output_rows = record_batch.num_rows();
121     if (selection_vector != nullptr) {
122       selection_buffer = selection_vector->GetBuffer().data();
123       num_output_rows = selection_vector->GetNumSlots();
124     }
125 
126     EvalFunc jit_function = compiled_expr->GetJITFunction(mode);
127     jit_function(eval_batch->GetBufferArray(), eval_batch->GetBufferOffsetArray(),
128                  eval_batch->GetLocalBitMapArray(), selection_buffer,
129                  (int64_t)eval_batch->GetExecutionContext(), num_output_rows);
130 
131     // check for execution errors
132     ARROW_RETURN_IF(
133         eval_batch->GetExecutionContext()->has_error(),
134         Status::ExecutionError(eval_batch->GetExecutionContext()->get_error()));
135 
136     // generate validity vectors.
137     ComputeBitMapsForExpr(*compiled_expr, *eval_batch, selection_vector);
138   }
139 
140   return Status::OK();
141 }
142 
LoadVectorAtIndex(llvm::Value * arg_addrs,int idx,const std::string & name)143 llvm::Value* LLVMGenerator::LoadVectorAtIndex(llvm::Value* arg_addrs, int idx,
144                                               const std::string& name) {
145   auto* idx_val = types()->i32_constant(idx);
146   auto* offset = ir_builder()->CreateGEP(arg_addrs, idx_val, name + "_mem_addr");
147   return ir_builder()->CreateLoad(offset, name + "_mem");
148 }
149 
150 /// Get reference to validity array at specified index in the args list.
GetValidityReference(llvm::Value * arg_addrs,int idx,FieldPtr field)151 llvm::Value* LLVMGenerator::GetValidityReference(llvm::Value* arg_addrs, int idx,
152                                                  FieldPtr field) {
153   const std::string& name = field->name();
154   llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
155   return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), name + "_varray");
156 }
157 
158 /// Get reference to data array at specified index in the args list.
GetDataBufferPtrReference(llvm::Value * arg_addrs,int idx,FieldPtr field)159 llvm::Value* LLVMGenerator::GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx,
160                                                       FieldPtr field) {
161   const std::string& name = field->name();
162   llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
163   return ir_builder()->CreateIntToPtr(load, types()->i8_ptr_type(), name + "_buf_ptr");
164 }
165 
166 /// Get reference to data array at specified index in the args list.
GetDataReference(llvm::Value * arg_addrs,int idx,FieldPtr field)167 llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx,
168                                              FieldPtr field) {
169   const std::string& name = field->name();
170   llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
171   llvm::Type* base_type = types()->DataVecType(field->type());
172   llvm::Value* ret;
173   if (base_type->isPointerTy()) {
174     ret = ir_builder()->CreateIntToPtr(load, base_type, name + "_darray");
175   } else {
176     llvm::Type* pointer_type = types()->ptr_type(base_type);
177     ret = ir_builder()->CreateIntToPtr(load, pointer_type, name + "_darray");
178   }
179   return ret;
180 }
181 
182 /// Get reference to offsets array at specified index in the args list.
GetOffsetsReference(llvm::Value * arg_addrs,int idx,FieldPtr field)183 llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx,
184                                                 FieldPtr field) {
185   const std::string& name = field->name();
186   llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
187   return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray");
188 }
189 
190 /// Get reference to local bitmap array at specified index in the args list.
GetLocalBitMapReference(llvm::Value * arg_bitmaps,int idx)191 llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) {
192   llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, idx, "");
193   return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(),
194                                       std::to_string(idx) + "_lbmap");
195 }
196 
197 /// \brief Generate code for one expression.
198 
199 // Sample IR code for "c1:int + c2:int"
200 //
201 // The C-code equivalent is :
202 // ------------------------------
203 // int expr_0(int64_t *addrs, int64_t *local_bitmaps,
204 //            int64_t execution_context_ptr, int64_t nrecords) {
205 //   int *outVec = (int *) addrs[5];
206 //   int *c0Vec = (int *) addrs[1];
207 //   int *c1Vec = (int *) addrs[3];
208 //   for (int loop_var = 0; loop_var < nrecords; ++loop_var) {
209 //     int c0 = c0Vec[loop_var];
210 //     int c1 = c1Vec[loop_var];
211 //     int out = c0 + c1;
212 //     outVec[loop_var] = out;
213 //   }
214 // }
215 //
216 // IR Code
217 // --------
218 //
219 // define i32 @expr_0(i64* %args, i64* %local_bitmaps, i64 %execution_context_ptr, , i64
220 // %nrecords) { entry:
221 //   %outmemAddr = getelementptr i64, i64* %args, i32 5
222 //   %outmem = load i64, i64* %outmemAddr
223 //   %outVec = inttoptr i64 %outmem to i32*
224 //   %c0memAddr = getelementptr i64, i64* %args, i32 1
225 //   %c0mem = load i64, i64* %c0memAddr
226 //   %c0Vec = inttoptr i64 %c0mem to i32*
227 //   %c1memAddr = getelementptr i64, i64* %args, i32 3
228 //   %c1mem = load i64, i64* %c1memAddr
229 //   %c1Vec = inttoptr i64 %c1mem to i32*
230 //   br label %loop
231 // loop:                                             ; preds = %loop, %entry
232 //   %loop_var = phi i64 [ 0, %entry ], [ %"loop_var+1", %loop ]
233 //   %"loop_var+1" = add i64 %loop_var, 1
234 //   %0 = getelementptr i32, i32* %c0Vec, i32 %loop_var
235 //   %c0 = load i32, i32* %0
236 //   %1 = getelementptr i32, i32* %c1Vec, i32 %loop_var
237 //   %c1 = load i32, i32* %1
238 //   %add_int_int = call i32 @add_int_int(i32 %c0, i32 %c1)
239 //   %2 = getelementptr i32, i32* %outVec, i32 %loop_var
240 //   store i32 %add_int_int, i32* %2
241 //   %"loop_var < nrec" = icmp slt i64 %"loop_var+1", %nrecords
242 //   br i1 %"loop_var < nrec", label %loop, label %exit
243 // exit:                                             ; preds = %loop
244 //   ret i32 0
245 // }
CodeGenExprValue(DexPtr value_expr,int buffer_count,FieldDescriptorPtr output,int suffix_idx,llvm::Function ** fn,SelectionVector::Mode selection_vector_mode)246 Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
247                                        FieldDescriptorPtr output, int suffix_idx,
248                                        llvm::Function** fn,
249                                        SelectionVector::Mode selection_vector_mode) {
250   llvm::IRBuilder<>* builder = ir_builder();
251   // Create fn prototype :
252   //   int expr_1 (long **addrs, long *offsets, long **bitmaps,
253   //               long *context_ptr, long nrec)
254   std::vector<llvm::Type*> arguments;
255   arguments.push_back(types()->i64_ptr_type());  // addrs
256   arguments.push_back(types()->i64_ptr_type());  // offsets
257   arguments.push_back(types()->i64_ptr_type());  // bitmaps
258   switch (selection_vector_mode) {
259     case SelectionVector::MODE_NONE:
260     case SelectionVector::MODE_UINT16:
261       arguments.push_back(types()->ptr_type(types()->i16_type()));
262       break;
263     case SelectionVector::MODE_UINT32:
264       arguments.push_back(types()->i32_ptr_type());
265       break;
266     case SelectionVector::MODE_UINT64:
267       arguments.push_back(types()->i64_ptr_type());
268   }
269   arguments.push_back(types()->i64_type());  // ctx_ptr
270   arguments.push_back(types()->i64_type());  // nrec
271   llvm::FunctionType* prototype =
272       llvm::FunctionType::get(types()->i32_type(), arguments, false /*isVarArg*/);
273 
274   // Create fn
275   std::string func_name = "expr_" + std::to_string(suffix_idx) + "_" +
276                           std::to_string(static_cast<int>(selection_vector_mode));
277   engine_->AddFunctionToCompile(func_name);
278   *fn = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, func_name,
279                                module());
280   ARROW_RETURN_IF((*fn == nullptr), Status::CodeGenError("Error creating function."));
281 
282   // Name the arguments
283   llvm::Function::arg_iterator args = (*fn)->arg_begin();
284   llvm::Value* arg_addrs = &*args;
285   arg_addrs->setName("inputs_addr");
286   ++args;
287   llvm::Value* arg_addr_offsets = &*args;
288   arg_addr_offsets->setName("inputs_addr_offsets");
289   ++args;
290   llvm::Value* arg_local_bitmaps = &*args;
291   arg_local_bitmaps->setName("local_bitmaps");
292   ++args;
293   llvm::Value* arg_selection_vector = &*args;
294   arg_selection_vector->setName("selection_vector");
295   ++args;
296   llvm::Value* arg_context_ptr = &*args;
297   arg_context_ptr->setName("context_ptr");
298   ++args;
299   llvm::Value* arg_nrecords = &*args;
300   arg_nrecords->setName("nrecords");
301 
302   llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context(), "entry", *fn);
303   llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context(), "loop", *fn);
304   llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context(), "exit", *fn);
305 
306   // Add reference to output vector (in entry block)
307   builder->SetInsertPoint(loop_entry);
308   llvm::Value* output_ref =
309       GetDataReference(arg_addrs, output->data_idx(), output->field());
310   llvm::Value* output_buffer_ptr_ref = GetDataBufferPtrReference(
311       arg_addrs, output->data_buffer_ptr_idx(), output->field());
312   llvm::Value* output_offset_ref =
313       GetOffsetsReference(arg_addrs, output->offsets_idx(), output->field());
314 
315   std::vector<llvm::Value*> slice_offsets;
316   for (int idx = 0; idx < buffer_count; idx++) {
317     auto offsetAddr = builder->CreateGEP(arg_addr_offsets, types()->i32_constant(idx));
318     auto offset = builder->CreateLoad(offsetAddr);
319     slice_offsets.push_back(offset);
320   }
321 
322   // Loop body
323   builder->SetInsertPoint(loop_body);
324 
325   // define loop_var : start with 0, +1 after each iter
326   llvm::PHINode* loop_var = builder->CreatePHI(types()->i64_type(), 2, "loop_var");
327 
328   llvm::Value* position_var = loop_var;
329   if (selection_vector_mode != SelectionVector::MODE_NONE) {
330     position_var = builder->CreateIntCast(
331         builder->CreateLoad(builder->CreateGEP(arg_selection_vector, loop_var),
332                             "uncasted_position_var"),
333         types()->i64_type(), true, "position_var");
334   }
335 
336   // The visitor can add code to both the entry/loop blocks.
337   Visitor visitor(this, *fn, loop_entry, arg_addrs, arg_local_bitmaps, slice_offsets,
338                   arg_context_ptr, position_var);
339   value_expr->Accept(visitor);
340   LValuePtr output_value = visitor.result();
341 
342   // The "current" block may have changed due to code generation in the visitor.
343   llvm::BasicBlock* loop_body_tail = builder->GetInsertBlock();
344 
345   // add jump to "loop block" at the end of the "setup block".
346   builder->SetInsertPoint(loop_entry);
347   builder->CreateBr(loop_body);
348 
349   // save the value in the output vector.
350   builder->SetInsertPoint(loop_body_tail);
351 
352   auto output_type_id = output->Type()->id();
353   if (output_type_id == arrow::Type::BOOL) {
354     SetPackedBitValue(output_ref, loop_var, output_value->data());
355   } else if (arrow::is_primitive(output_type_id) ||
356              output_type_id == arrow::Type::DECIMAL) {
357     llvm::Value* slot_offset = builder->CreateGEP(output_ref, loop_var);
358     builder->CreateStore(output_value->data(), slot_offset);
359   } else if (arrow::is_binary_like(output_type_id)) {
360     // Var-len output. Make a function call to populate the data.
361     // if there is an error, the fn sets it in the context. And, will be returned at the
362     // end of this row batch.
363     AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(),
364                     {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var,
365                      output_value->data(), output_value->length()});
366   } else {
367     return Status::NotImplemented("output type ", output->Type()->ToString(),
368                                   " not supported");
369   }
370   ADD_TRACE("saving result " + output->Name() + " value %T", output_value->data());
371 
372   if (visitor.has_arena_allocs()) {
373     // Reset allocations to avoid excessive memory usage. Once the result is copied to
374     // the output vector (store instruction above), any memory allocations in this
375     // iteration of the loop are no longer needed.
376     std::vector<llvm::Value*> reset_args;
377     reset_args.push_back(arg_context_ptr);
378     AddFunctionCall("gdv_fn_context_arena_reset", types()->void_type(), reset_args);
379   }
380 
381   // check loop_var
382   loop_var->addIncoming(types()->i64_constant(0), loop_entry);
383   llvm::Value* loop_update =
384       builder->CreateAdd(loop_var, types()->i64_constant(1), "loop_var+1");
385   loop_var->addIncoming(loop_update, loop_body_tail);
386 
387   llvm::Value* loop_var_check =
388       builder->CreateICmpSLT(loop_update, arg_nrecords, "loop_var < nrec");
389   builder->CreateCondBr(loop_var_check, loop_body, loop_exit);
390 
391   // Loop exit
392   builder->SetInsertPoint(loop_exit);
393   builder->CreateRet(types()->i32_constant(0));
394   return Status::OK();
395 }
396 
397 /// Return value of a bit in bitMap.
GetPackedBitValue(llvm::Value * bitmap,llvm::Value * position)398 llvm::Value* LLVMGenerator::GetPackedBitValue(llvm::Value* bitmap,
399                                               llvm::Value* position) {
400   ADD_TRACE("fetch bit at position %T", position);
401 
402   llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
403       bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
404   return AddFunctionCall("bitMapGetBit", types()->i1_type(), {bitmap8, position});
405 }
406 
407 /// Set the value of a bit in bitMap.
SetPackedBitValue(llvm::Value * bitmap,llvm::Value * position,llvm::Value * value)408 void LLVMGenerator::SetPackedBitValue(llvm::Value* bitmap, llvm::Value* position,
409                                       llvm::Value* value) {
410   ADD_TRACE("set bit at position %T", position);
411   ADD_TRACE("  to value %T ", value);
412 
413   llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
414       bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
415   AddFunctionCall("bitMapSetBit", types()->void_type(), {bitmap8, position, value});
416 }
417 
418 /// Return value of a bit in validity bitMap (handles null bitmaps too).
GetPackedValidityBitValue(llvm::Value * bitmap,llvm::Value * position)419 llvm::Value* LLVMGenerator::GetPackedValidityBitValue(llvm::Value* bitmap,
420                                                       llvm::Value* position) {
421   ADD_TRACE("fetch validity bit at position %T", position);
422 
423   llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
424       bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
425   return AddFunctionCall("bitMapValidityGetBit", types()->i1_type(), {bitmap8, position});
426 }
427 
428 /// Clear the bit in bitMap if value = false.
ClearPackedBitValueIfFalse(llvm::Value * bitmap,llvm::Value * position,llvm::Value * value)429 void LLVMGenerator::ClearPackedBitValueIfFalse(llvm::Value* bitmap, llvm::Value* position,
430                                                llvm::Value* value) {
431   ADD_TRACE("ClearIfFalse bit at position %T", position);
432   ADD_TRACE("   value %T ", value);
433 
434   llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
435       bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
436   AddFunctionCall("bitMapClearBitIfFalse", types()->void_type(),
437                   {bitmap8, position, value});
438 }
439 
440 /// Extract the bitmap addresses, and do an intersection.
ComputeBitMapsForExpr(const CompiledExpr & compiled_expr,const EvalBatch & eval_batch,const SelectionVector * selection_vector)441 void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr,
442                                           const EvalBatch& eval_batch,
443                                           const SelectionVector* selection_vector) {
444   auto validities = compiled_expr.value_validity()->validity_exprs();
445 
446   // Extract all the source bitmap addresses.
447   BitMapAccumulator accumulator(eval_batch);
448   for (auto& validity_dex : validities) {
449     validity_dex->Accept(accumulator);
450   }
451 
452   // Extract the destination bitmap address.
453   int out_idx = compiled_expr.output()->validity_idx();
454   uint8_t* dst_bitmap = eval_batch.GetBuffer(out_idx);
455   // Compute the destination bitmap.
456   if (selection_vector == nullptr) {
457     accumulator.ComputeResult(dst_bitmap);
458   } else {
459     /// The output bitmap is an intersection of some input/local bitmaps. However, with a
460     /// selection vector, only the bits corresponding to the indices in the selection
461     /// vector need to set in the output bitmap. This is done in two steps :
462     ///
463     /// 1. Do the intersection of input/local bitmaps to generate a temporary bitmap.
464     /// 2. copy just the relevant bits from the temporary bitmap to the output bitmap.
465     LocalBitMapsHolder bit_map_holder(eval_batch.num_records(), 1);
466     uint8_t* temp_bitmap = bit_map_holder.GetLocalBitMap(0);
467     accumulator.ComputeResult(temp_bitmap);
468 
469     auto num_out_records = selection_vector->GetNumSlots();
470     // the memset isn't required, doing it just for valgrind.
471     memset(dst_bitmap, 0, arrow::BitUtil::BytesForBits(num_out_records));
472     for (auto i = 0; i < num_out_records; ++i) {
473       auto bit = arrow::BitUtil::GetBit(temp_bitmap, selection_vector->GetIndex(i));
474       arrow::BitUtil::SetBitTo(dst_bitmap, i, bit);
475     }
476   }
477 }
478 
AddFunctionCall(const std::string & full_name,llvm::Type * ret_type,const std::vector<llvm::Value * > & args)479 llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name,
480                                             llvm::Type* ret_type,
481                                             const std::vector<llvm::Value*>& args) {
482   // find the llvm function.
483   llvm::Function* fn = module()->getFunction(full_name);
484   DCHECK_NE(fn, nullptr) << "missing function " << full_name;
485 
486   if (enable_ir_traces_ && !full_name.compare("printf") &&
487       !full_name.compare("printff")) {
488     // Trace for debugging
489     ADD_TRACE("invoke native fn " + full_name);
490   }
491 
492   // build a call to the llvm function.
493   llvm::Value* value;
494   if (ret_type->isVoidTy()) {
495     // void functions can't have a name for the call.
496     value = ir_builder()->CreateCall(fn, args);
497   } else {
498     value = ir_builder()->CreateCall(fn, args, full_name);
499     DCHECK(value->getType() == ret_type);
500   }
501 
502   return value;
503 }
504 
BuildDecimalLValue(llvm::Value * value,DataTypePtr arrow_type)505 std::shared_ptr<DecimalLValue> LLVMGenerator::BuildDecimalLValue(llvm::Value* value,
506                                                                  DataTypePtr arrow_type) {
507   // only decimals of size 128-bit supported.
508   DCHECK(is_decimal_128(arrow_type));
509   auto decimal_type =
510       arrow::internal::checked_cast<arrow::DecimalType*>(arrow_type.get());
511   return std::make_shared<DecimalLValue>(value, nullptr,
512                                          types()->i32_constant(decimal_type->precision()),
513                                          types()->i32_constant(decimal_type->scale()));
514 }
515 
// Emit an IR trace through the owning generator only when its IR tracing is enabled.
// The do { } while (0) wrapper makes `ADD_VISITOR_TRACE(...);` a single statement, so
// it is safe to use as the body of an unbraced if/else.
#define ADD_VISITOR_TRACE(...)           \
  do {                                   \
    if (generator_->enable_ir_traces_) { \
      generator_->AddTrace(__VA_ARGS__); \
    }                                    \
  } while (0)
520 
521 // Visitor for generating the code for a decomposed expression.
Visitor(LLVMGenerator * generator,llvm::Function * function,llvm::BasicBlock * entry_block,llvm::Value * arg_addrs,llvm::Value * arg_local_bitmaps,std::vector<llvm::Value * > slice_offsets,llvm::Value * arg_context_ptr,llvm::Value * loop_var)522 LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* function,
523                                 llvm::BasicBlock* entry_block, llvm::Value* arg_addrs,
524                                 llvm::Value* arg_local_bitmaps,
525                                 std::vector<llvm::Value*> slice_offsets,
526                                 llvm::Value* arg_context_ptr, llvm::Value* loop_var)
527     : generator_(generator),
528       function_(function),
529       entry_block_(entry_block),
530       arg_addrs_(arg_addrs),
531       arg_local_bitmaps_(arg_local_bitmaps),
532       slice_offsets_(slice_offsets),
533       arg_context_ptr_(arg_context_ptr),
534       loop_var_(loop_var),
535       has_arena_allocs_(false) {
536   ADD_VISITOR_TRACE("Iteration %T", loop_var);
537 }
538 
Visit(const VectorReadFixedLenValueDex & dex)539 void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) {
540   llvm::IRBuilder<>* builder = ir_builder();
541   llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
542   llvm::Value* slot_index = builder->CreateAdd(loop_var_, GetSliceOffset(dex.DataIdx()));
543   llvm::Value* slot_value;
544   std::shared_ptr<LValue> lvalue;
545 
546   switch (dex.FieldType()->id()) {
547     case arrow::Type::BOOL:
548       slot_value = generator_->GetPackedBitValue(slot_ref, slot_index);
549       lvalue = std::make_shared<LValue>(slot_value);
550       break;
551 
552     case arrow::Type::DECIMAL: {
553       auto slot_offset = builder->CreateGEP(slot_ref, slot_index);
554       slot_value = builder->CreateLoad(slot_offset, dex.FieldName());
555       lvalue = generator_->BuildDecimalLValue(slot_value, dex.FieldType());
556       break;
557     }
558 
559     default: {
560       auto slot_offset = builder->CreateGEP(slot_ref, slot_index);
561       slot_value = builder->CreateLoad(slot_offset, dex.FieldName());
562       lvalue = std::make_shared<LValue>(slot_value);
563       break;
564     }
565   }
566   ADD_VISITOR_TRACE("visit fixed-len data vector " + dex.FieldName() + " value %T",
567                     slot_value);
568   result_ = lvalue;
569 }
570 
Visit(const VectorReadVarLenValueDex & dex)571 void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) {
572   llvm::IRBuilder<>* builder = ir_builder();
573   llvm::Value* slot;
574 
575   // compute len from the offsets array.
576   llvm::Value* offsets_slot_ref =
577       GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field());
578   llvm::Value* offsets_slot_index =
579       builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx()));
580 
581   // => offset_start = offsets[loop_var]
582   slot = builder->CreateGEP(offsets_slot_ref, offsets_slot_index);
583   llvm::Value* offset_start = builder->CreateLoad(slot, "offset_start");
584 
585   // => offset_end = offsets[loop_var + 1]
586   llvm::Value* offsets_slot_index_next = builder->CreateAdd(
587       offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1");
588   slot = builder->CreateGEP(offsets_slot_ref, offsets_slot_index_next);
589   llvm::Value* offset_end = builder->CreateLoad(slot, "offset_end");
590 
591   // => len_value = offset_end - offset_start
592   llvm::Value* len_value =
593       builder->CreateSub(offset_end, offset_start, dex.FieldName() + "Len");
594 
595   // get the data from the data array, at offset 'offset_start'.
596   llvm::Value* data_slot_ref =
597       GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
598   llvm::Value* data_value = builder->CreateGEP(data_slot_ref, offset_start);
599   ADD_VISITOR_TRACE("visit var-len data vector " + dex.FieldName() + " len %T",
600                     len_value);
601   result_.reset(new LValue(data_value, len_value));
602 }
603 
Visit(const VectorReadValidityDex & dex)604 void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) {
605   llvm::IRBuilder<>* builder = ir_builder();
606   llvm::Value* slot_ref =
607       GetBufferReference(dex.ValidityIdx(), kBufferTypeValidity, dex.Field());
608   llvm::Value* slot_index =
609       builder->CreateAdd(loop_var_, GetSliceOffset(dex.ValidityIdx()));
610   llvm::Value* validity = generator_->GetPackedValidityBitValue(slot_ref, slot_index);
611 
612   ADD_VISITOR_TRACE("visit validity vector " + dex.FieldName() + " value %T", validity);
613   result_.reset(new LValue(validity));
614 }
615 
Visit(const LocalBitMapValidityDex & dex)616 void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) {
617   llvm::Value* slot_ref = GetLocalBitMapReference(dex.local_bitmap_idx());
618   llvm::Value* validity = generator_->GetPackedBitValue(slot_ref, loop_var_);
619 
620   ADD_VISITOR_TRACE(
621       "visit local bitmap " + std::to_string(dex.local_bitmap_idx()) + " value %T",
622       validity);
623   result_.reset(new LValue(validity));
624 }
625 
Visit(const TrueDex & dex)626 void LLVMGenerator::Visitor::Visit(const TrueDex& dex) {
627   result_.reset(new LValue(generator_->types()->true_constant()));
628 }
629 
Visit(const FalseDex & dex)630 void LLVMGenerator::Visitor::Visit(const FalseDex& dex) {
631   result_.reset(new LValue(generator_->types()->false_constant()));
632 }
633 
Visit(const LiteralDex & dex)634 void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) {
635   LLVMTypes* types = generator_->types();
636   llvm::Value* value = nullptr;
637   llvm::Value* len = nullptr;
638 
639   switch (dex.type()->id()) {
640     case arrow::Type::BOOL:
641       value = types->i1_constant(arrow::util::get<bool>(dex.holder()));
642       break;
643 
644     case arrow::Type::UINT8:
645       value = types->i8_constant(arrow::util::get<uint8_t>(dex.holder()));
646       break;
647 
648     case arrow::Type::UINT16:
649       value = types->i16_constant(arrow::util::get<uint16_t>(dex.holder()));
650       break;
651 
652     case arrow::Type::UINT32:
653       value = types->i32_constant(arrow::util::get<uint32_t>(dex.holder()));
654       break;
655 
656     case arrow::Type::UINT64:
657       value = types->i64_constant(arrow::util::get<uint64_t>(dex.holder()));
658       break;
659 
660     case arrow::Type::INT8:
661       value = types->i8_constant(arrow::util::get<int8_t>(dex.holder()));
662       break;
663 
664     case arrow::Type::INT16:
665       value = types->i16_constant(arrow::util::get<int16_t>(dex.holder()));
666       break;
667 
668     case arrow::Type::FLOAT:
669       value = types->float_constant(arrow::util::get<float>(dex.holder()));
670       break;
671 
672     case arrow::Type::DOUBLE:
673       value = types->double_constant(arrow::util::get<double>(dex.holder()));
674       break;
675 
676     case arrow::Type::STRING:
677     case arrow::Type::BINARY: {
678       const std::string& str = arrow::util::get<std::string>(dex.holder());
679 
680       llvm::Constant* str_int_cast = types->i64_constant((int64_t)str.c_str());
681       value = llvm::ConstantExpr::getIntToPtr(str_int_cast, types->i8_ptr_type());
682       len = types->i32_constant(static_cast<int32_t>(str.length()));
683       break;
684     }
685 
686     case arrow::Type::INT32:
687     case arrow::Type::DATE32:
688     case arrow::Type::TIME32:
689     case arrow::Type::INTERVAL_MONTHS:
690       value = types->i32_constant(arrow::util::get<int32_t>(dex.holder()));
691       break;
692 
693     case arrow::Type::INT64:
694     case arrow::Type::DATE64:
695     case arrow::Type::TIME64:
696     case arrow::Type::TIMESTAMP:
697     case arrow::Type::INTERVAL_DAY_TIME:
698       value = types->i64_constant(arrow::util::get<int64_t>(dex.holder()));
699       break;
700 
701     case arrow::Type::DECIMAL: {
702       // build code for struct
703       auto scalar = arrow::util::get<DecimalScalar128>(dex.holder());
704       // ConstantInt doesn't have a get method that takes int128 or a pair of int64. so,
705       // passing the string representation instead.
706       auto int128_value =
707           llvm::ConstantInt::get(llvm::Type::getInt128Ty(*generator_->context()),
708                                  Decimal128(scalar.value()).ToIntegerString(), 10);
709       auto type = arrow::decimal(scalar.precision(), scalar.scale());
710       auto lvalue = generator_->BuildDecimalLValue(int128_value, type);
711       // set it as the l-value and return.
712       result_ = lvalue;
713       return;
714     }
715 
716     default:
717       DCHECK(0);
718   }
719   ADD_VISITOR_TRACE("visit Literal %T", value);
720   result_.reset(new LValue(value, len));
721 }
722 
Visit(const NonNullableFuncDex & dex)723 void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
724   const std::string& function_name = dex.func_descriptor()->name();
725   ADD_VISITOR_TRACE("visit NonNullableFunc base function " + function_name);
726 
727   const NativeFunction* native_function = dex.native_function();
728 
729   // build the function params (ignore validity).
730   auto params = BuildParams(dex.function_holder().get(), dex.args(), false,
731                             native_function->NeedsContext());
732 
733   auto arrow_return_type = dex.func_descriptor()->return_type();
734   if (native_function->CanReturnErrors()) {
735     // slow path : if a function can return errors, skip invoking the function
736     // unless all of the input args are valid. Otherwise, it can cause spurious errors.
737 
738     llvm::IRBuilder<>* builder = ir_builder();
739     LLVMTypes* types = generator_->types();
740     auto arrow_type_id = arrow_return_type->id();
741     auto result_type = types->IRType(arrow_type_id);
742 
743     // Build combined validity of the args.
744     llvm::Value* is_valid = types->true_constant();
745     for (auto& pair : dex.args()) {
746       auto arg_validity = BuildCombinedValidity(pair->validity_exprs());
747       is_valid = builder->CreateAnd(is_valid, arg_validity, "validityBitAnd");
748     }
749 
750     // then block
751     auto then_lambda = [&] {
752       ADD_VISITOR_TRACE("fn " + function_name +
753                         " can return errors : all args valid, invoke fn");
754       return BuildFunctionCall(native_function, arrow_return_type, &params);
755     };
756 
757     // else block
758     auto else_lambda = [&] {
759       ADD_VISITOR_TRACE("fn " + function_name +
760                         " can return errors : not all args valid, return dummy value");
761       llvm::Value* else_value = types->NullConstant(result_type);
762       llvm::Value* else_value_len = nullptr;
763       if (arrow::is_binary_like(arrow_type_id)) {
764         else_value_len = types->i32_constant(0);
765       }
766       return std::make_shared<LValue>(else_value, else_value_len);
767     };
768 
769     result_ = BuildIfElse(is_valid, then_lambda, else_lambda, arrow_return_type);
770   } else {
771     // fast path : invoke function without computing validities.
772     result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
773   }
774 }
775 
Visit(const NullableNeverFuncDex & dex)776 void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex& dex) {
777   ADD_VISITOR_TRACE("visit NullableNever base function " + dex.func_descriptor()->name());
778   const NativeFunction* native_function = dex.native_function();
779 
780   // build function params along with validity.
781   auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
782                             native_function->NeedsContext());
783 
784   auto arrow_return_type = dex.func_descriptor()->return_type();
785   result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
786 }
787 
Visit(const NullableInternalFuncDex & dex)788 void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) {
789   ADD_VISITOR_TRACE("visit NullableInternal base function " +
790                     dex.func_descriptor()->name());
791   llvm::IRBuilder<>* builder = ir_builder();
792   LLVMTypes* types = generator_->types();
793 
794   const NativeFunction* native_function = dex.native_function();
795 
796   // build function params along with validity.
797   auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
798                             native_function->NeedsContext());
799 
800   // add an extra arg for validity (allocated on stack).
801   llvm::AllocaInst* result_valid_ptr =
802       new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_);
803   params.push_back(result_valid_ptr);
804 
805   auto arrow_return_type = dex.func_descriptor()->return_type();
806   result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
807 
808   // load the result validity and truncate to i1.
809   llvm::Value* result_valid_i8 = builder->CreateLoad(result_valid_ptr);
810   llvm::Value* result_valid = builder->CreateTrunc(result_valid_i8, types->i1_type());
811 
812   // set validity bit in the local bitmap.
813   ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), result_valid);
814 }
815 
Visit(const IfDex & dex)816 void LLVMGenerator::Visitor::Visit(const IfDex& dex) {
817   ADD_VISITOR_TRACE("visit IfExpression");
818   llvm::IRBuilder<>* builder = ir_builder();
819 
820   // Evaluate condition.
821   LValuePtr if_condition = BuildValueAndValidity(dex.condition_vv());
822 
823   // Check if the result is valid, and there is match.
824   llvm::Value* validAndMatched =
825       builder->CreateAnd(if_condition->data(), if_condition->validity(), "validAndMatch");
826 
827   // then block
828   auto then_lambda = [&] {
829     ADD_VISITOR_TRACE("branch to then block");
830     LValuePtr then_lvalue = BuildValueAndValidity(dex.then_vv());
831     ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), then_lvalue->validity());
832     ADD_VISITOR_TRACE("IfExpression result validity %T in matching then",
833                       then_lvalue->validity());
834     return then_lvalue;
835   };
836 
837   // else block
838   auto else_lambda = [&] {
839     LValuePtr else_lvalue;
840     if (dex.is_terminal_else()) {
841       ADD_VISITOR_TRACE("branch to terminal else block");
842 
843       else_lvalue = BuildValueAndValidity(dex.else_vv());
844       // update the local bitmap with the validity.
845       ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), else_lvalue->validity());
846       ADD_VISITOR_TRACE("IfExpression result validity %T in terminal else",
847                         else_lvalue->validity());
848     } else {
849       ADD_VISITOR_TRACE("branch to non-terminal else block");
850 
851       // this is a non-terminal else. let the child (nested if/else) handle validity.
852       auto value_expr = dex.else_vv().value_expr();
853       value_expr->Accept(*this);
854       else_lvalue = result();
855     }
856     return else_lvalue;
857   };
858 
859   // build the if-else condition.
860   result_ = BuildIfElse(validAndMatched, then_lambda, else_lambda, dex.result_type());
861   if (arrow::is_binary_like(dex.result_type()->id())) {
862     ADD_VISITOR_TRACE("IfElse result length %T", result_->length());
863   }
864   ADD_VISITOR_TRACE("IfElse result value %T", result_->data());
865 }
866 
867 // Boolean AND
868 // if any arg is valid and false,
869 //   short-circuit and return FALSE (value=false, valid=true)
870 // else if all args are valid and true
871 //   return TRUE (value=true, valid=true)
872 // else
873 //   return NULL (value=true, valid=false)
874 
Visit(const BooleanAndDex & dex)875 void LLVMGenerator::Visitor::Visit(const BooleanAndDex& dex) {
876   ADD_VISITOR_TRACE("visit BooleanAndExpression");
877   llvm::IRBuilder<>* builder = ir_builder();
878   LLVMTypes* types = generator_->types();
879   llvm::LLVMContext* context = generator_->context();
880 
881   // Create blocks for short-circuit.
882   llvm::BasicBlock* short_circuit_bb =
883       llvm::BasicBlock::Create(*context, "short_circuit", function_);
884   llvm::BasicBlock* non_short_circuit_bb =
885       llvm::BasicBlock::Create(*context, "non_short_circuit", function_);
886   llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
887 
888   llvm::Value* all_exprs_valid = types->true_constant();
889   for (auto& pair : dex.args()) {
890     LValuePtr current = BuildValueAndValidity(*pair);
891 
892     ADD_VISITOR_TRACE("BooleanAndExpression arg value %T", current->data());
893     ADD_VISITOR_TRACE("BooleanAndExpression arg validity %T", current->validity());
894 
895     // short-circuit if valid and false
896     llvm::Value* is_false = builder->CreateNot(current->data());
897     llvm::Value* valid_and_false =
898         builder->CreateAnd(is_false, current->validity(), "valid_and_false");
899 
900     llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
901     builder->CreateCondBr(valid_and_false, short_circuit_bb, else_bb);
902 
903     // Emit the else block.
904     builder->SetInsertPoint(else_bb);
905     // remember if any nulls were encountered.
906     all_exprs_valid =
907         builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd");
908     // continue to evaluate the next pair in list.
909   }
910   builder->CreateBr(non_short_circuit_bb);
911 
912   // Short-circuit case (at least one of the expressions is valid and false).
913   // No need to set validity bit (valid by default).
914   builder->SetInsertPoint(short_circuit_bb);
915   ADD_VISITOR_TRACE("BooleanAndExpression result value false");
916   ADD_VISITOR_TRACE("BooleanAndExpression result validity true");
917   builder->CreateBr(merge_bb);
918 
919   // non short-circuit case (All expressions are either true or null).
920   // result valid if all of the exprs are non-null.
921   builder->SetInsertPoint(non_short_circuit_bb);
922   ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid);
923   ADD_VISITOR_TRACE("BooleanAndExpression result value true");
924   ADD_VISITOR_TRACE("BooleanAndExpression result validity %T", all_exprs_valid);
925   builder->CreateBr(merge_bb);
926 
927   builder->SetInsertPoint(merge_bb);
928   llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value");
929   result_value->addIncoming(types->false_constant(), short_circuit_bb);
930   result_value->addIncoming(types->true_constant(), non_short_circuit_bb);
931   result_.reset(new LValue(result_value));
932 }
933 
934 // Boolean OR
935 // if any arg is valid and true,
936 //   short-circuit and return TRUE (value=true, valid=true)
937 // else if all args are valid and false
938 //   return FALSE (value=false, valid=true)
939 // else
940 //   return NULL (value=false, valid=false)
941 
Visit(const BooleanOrDex & dex)942 void LLVMGenerator::Visitor::Visit(const BooleanOrDex& dex) {
943   ADD_VISITOR_TRACE("visit BooleanOrExpression");
944   llvm::IRBuilder<>* builder = ir_builder();
945   LLVMTypes* types = generator_->types();
946   llvm::LLVMContext* context = generator_->context();
947 
948   // Create blocks for short-circuit.
949   llvm::BasicBlock* short_circuit_bb =
950       llvm::BasicBlock::Create(*context, "short_circuit", function_);
951   llvm::BasicBlock* non_short_circuit_bb =
952       llvm::BasicBlock::Create(*context, "non_short_circuit", function_);
953   llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
954 
955   llvm::Value* all_exprs_valid = types->true_constant();
956   for (auto& pair : dex.args()) {
957     LValuePtr current = BuildValueAndValidity(*pair);
958 
959     ADD_VISITOR_TRACE("BooleanOrExpression arg value %T", current->data());
960     ADD_VISITOR_TRACE("BooleanOrExpression arg validity %T", current->validity());
961 
962     // short-circuit if valid and true.
963     llvm::Value* valid_and_true =
964         builder->CreateAnd(current->data(), current->validity(), "valid_and_true");
965 
966     llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
967     builder->CreateCondBr(valid_and_true, short_circuit_bb, else_bb);
968 
969     // Emit the else block.
970     builder->SetInsertPoint(else_bb);
971     // remember if any nulls were encountered.
972     all_exprs_valid =
973         builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd");
974     // continue to evaluate the next pair in list.
975   }
976   builder->CreateBr(non_short_circuit_bb);
977 
978   // Short-circuit case (at least one of the expressions is valid and true).
979   // No need to set validity bit (valid by default).
980   builder->SetInsertPoint(short_circuit_bb);
981   ADD_VISITOR_TRACE("BooleanOrExpression result value true");
982   ADD_VISITOR_TRACE("BooleanOrExpression result validity true");
983   builder->CreateBr(merge_bb);
984 
985   // non short-circuit case (All expressions are either false or null).
986   // result valid if all of the exprs are non-null.
987   builder->SetInsertPoint(non_short_circuit_bb);
988   ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid);
989   ADD_VISITOR_TRACE("BooleanOrExpression result value false");
990   ADD_VISITOR_TRACE("BooleanOrExpression result validity %T", all_exprs_valid);
991   builder->CreateBr(merge_bb);
992 
993   builder->SetInsertPoint(merge_bb);
994   llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value");
995   result_value->addIncoming(types->true_constant(), short_circuit_bb);
996   result_value->addIncoming(types->false_constant(), non_short_circuit_bb);
997   result_.reset(new LValue(result_value));
998 }
999 
Visit(const InExprDexBase<int32_t> & dex)1000 void LLVMGenerator::Visitor::Visit(const InExprDexBase<int32_t>& dex) {
1001   VisitInExpression<int32_t>(dex);
1002 }
1003 
Visit(const InExprDexBase<int64_t> & dex)1004 void LLVMGenerator::Visitor::Visit(const InExprDexBase<int64_t>& dex) {
1005   VisitInExpression<int64_t>(dex);
1006 }
1007 
Visit(const InExprDexBase<std::string> & dex)1008 void LLVMGenerator::Visitor::Visit(const InExprDexBase<std::string>& dex) {
1009   VisitInExpression<std::string>(dex);
1010 }
1011 
1012 template <typename Type>
VisitInExpression(const InExprDexBase<Type> & dex)1013 void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase<Type>& dex) {
1014   ADD_VISITOR_TRACE("visit In Expression");
1015   LLVMTypes* types = generator_->types();
1016   std::vector<llvm::Value*> params;
1017 
1018   const InExprDex<Type>& dex_instance = dynamic_cast<const InExprDex<Type>&>(dex);
1019   /* add the holder at the beginning */
1020   llvm::Constant* ptr_int_cast =
1021       types->i64_constant((int64_t)(dex_instance.in_holder().get()));
1022   params.push_back(ptr_int_cast);
1023 
1024   /* eval expr result */
1025   for (auto& pair : dex.args()) {
1026     DexPtr value_expr = pair->value_expr();
1027     value_expr->Accept(*this);
1028     LValue& result_ref = *result();
1029     params.push_back(result_ref.data());
1030 
1031     /* length if the result is a string */
1032     if (result_ref.length() != nullptr) {
1033       params.push_back(result_ref.length());
1034     }
1035 
1036     /* push the validity of eval expr result */
1037     llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
1038     params.push_back(validity_expr);
1039   }
1040 
1041   llvm::Type* ret_type = types->IRType(arrow::Type::type::BOOL);
1042 
1043   llvm::Value* value =
1044       generator_->AddFunctionCall(dex.runtime_function(), ret_type, params);
1045   result_.reset(new LValue(value));
1046 }
1047 
BuildIfElse(llvm::Value * condition,std::function<LValuePtr ()> then_func,std::function<LValuePtr ()> else_func,DataTypePtr result_type)1048 LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition,
1049                                               std::function<LValuePtr()> then_func,
1050                                               std::function<LValuePtr()> else_func,
1051                                               DataTypePtr result_type) {
1052   llvm::IRBuilder<>* builder = ir_builder();
1053   llvm::LLVMContext* context = generator_->context();
1054   LLVMTypes* types = generator_->types();
1055 
1056   // Create blocks for the then, else and merge cases.
1057   llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(*context, "then", function_);
1058   llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
1059   llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
1060 
1061   builder->CreateCondBr(condition, then_bb, else_bb);
1062 
1063   // Emit the then block.
1064   builder->SetInsertPoint(then_bb);
1065   LValuePtr then_lvalue = then_func();
1066   builder->CreateBr(merge_bb);
1067 
1068   // refresh then_bb for phi (could have changed due to code generation of then_vv).
1069   then_bb = builder->GetInsertBlock();
1070 
1071   // Emit the else block.
1072   builder->SetInsertPoint(else_bb);
1073   LValuePtr else_lvalue = else_func();
1074   builder->CreateBr(merge_bb);
1075 
1076   // refresh else_bb for phi (could have changed due to code generation of else_vv).
1077   else_bb = builder->GetInsertBlock();
1078 
1079   // Emit the merge block.
1080   builder->SetInsertPoint(merge_bb);
1081   auto llvm_type = types->IRType(result_type->id());
1082   llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value");
1083   result_value->addIncoming(then_lvalue->data(), then_bb);
1084   result_value->addIncoming(else_lvalue->data(), else_bb);
1085 
1086   LValuePtr ret;
1087   switch (result_type->id()) {
1088     case arrow::Type::STRING: {
1089       llvm::PHINode* result_length;
1090       result_length = builder->CreatePHI(types->i32_type(), 2, "res_length");
1091       result_length->addIncoming(then_lvalue->length(), then_bb);
1092       result_length->addIncoming(else_lvalue->length(), else_bb);
1093       ret = std::make_shared<LValue>(result_value, result_length);
1094       break;
1095     }
1096 
1097     case arrow::Type::DECIMAL:
1098       ret = generator_->BuildDecimalLValue(result_value, result_type);
1099       break;
1100 
1101     default:
1102       ret = std::make_shared<LValue>(result_value);
1103       break;
1104   }
1105   return ret;
1106 }
1107 
BuildValueAndValidity(const ValueValidityPair & pair)1108 LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair& pair) {
1109   // generate code for value
1110   auto value_expr = pair.value_expr();
1111   value_expr->Accept(*this);
1112   auto value = result()->data();
1113   auto length = result()->length();
1114 
1115   // generate code for validity
1116   auto validity = BuildCombinedValidity(pair.validity_exprs());
1117 
1118   return std::make_shared<LValue>(value, length, validity);
1119 }
1120 
BuildFunctionCall(const NativeFunction * func,DataTypePtr arrow_return_type,std::vector<llvm::Value * > * params)1121 LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
1122                                                     DataTypePtr arrow_return_type,
1123                                                     std::vector<llvm::Value*>* params) {
1124   auto types = generator_->types();
1125   auto arrow_return_type_id = arrow_return_type->id();
1126   auto llvm_return_type = types->IRType(arrow_return_type_id);
1127   DecimalIR decimalIR(generator_->engine_.get());
1128 
1129   if (arrow_return_type_id == arrow::Type::DECIMAL) {
1130     // For decimal fns, the output precision/scale are passed along as parameters.
1131     //
1132     // convert from this :
1133     //     out = add_decimal(v1, p1, s1, v2, p2, s2)
1134     // to:
1135     //     out = add_decimal(v1, p1, s1, v2, p2, s2, out_p, out_s)
1136 
1137     // Append the out_precision and out_scale
1138     auto ret_lvalue = generator_->BuildDecimalLValue(nullptr, arrow_return_type);
1139     params->push_back(ret_lvalue->precision());
1140     params->push_back(ret_lvalue->scale());
1141 
1142     // Make the function call
1143     auto out = decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params);
1144     ret_lvalue->set_data(out);
1145     return std::move(ret_lvalue);
1146   } else {
1147     bool isDecimalFunction = false;
1148     for (auto& arg : *params) {
1149       if (arg->getType() == types->i128_type()) {
1150         isDecimalFunction = true;
1151       }
1152     }
1153     // add extra arg for return length for variable len return types (allocated on stack).
1154     llvm::AllocaInst* result_len_ptr = nullptr;
1155     if (arrow::is_binary_like(arrow_return_type_id)) {
1156       result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0,
1157                                             "result_len", entry_block_);
1158       params->push_back(result_len_ptr);
1159       has_arena_allocs_ = true;
1160     }
1161 
1162     // Make the function call
1163     llvm::IRBuilder<>* builder = ir_builder();
1164     auto value =
1165         isDecimalFunction
1166             ? decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params)
1167             : generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params);
1168     auto value_len =
1169         (result_len_ptr == nullptr) ? nullptr : builder->CreateLoad(result_len_ptr);
1170     return std::make_shared<LValue>(value, value_len);
1171   }
1172 }
1173 
BuildParams(FunctionHolder * holder,const ValueValidityPairVector & args,bool with_validity,bool with_context)1174 std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams(
1175     FunctionHolder* holder, const ValueValidityPairVector& args, bool with_validity,
1176     bool with_context) {
1177   LLVMTypes* types = generator_->types();
1178   std::vector<llvm::Value*> params;
1179 
1180   // add context if required.
1181   if (with_context) {
1182     params.push_back(arg_context_ptr_);
1183   }
1184 
1185   // if the function has holder, add the holder pointer.
1186   if (holder != nullptr) {
1187     auto ptr = types->i64_constant((int64_t)holder);
1188     params.push_back(ptr);
1189   }
1190 
1191   // build the function params, along with the validities.
1192   for (auto& pair : args) {
1193     // build value.
1194     DexPtr value_expr = pair->value_expr();
1195     value_expr->Accept(*this);
1196     LValue& result_ref = *result();
1197 
1198     // append all the parameters corresponding to this LValue.
1199     result_ref.AppendFunctionParams(&params);
1200 
1201     // build validity.
1202     if (with_validity) {
1203       llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
1204       params.push_back(validity_expr);
1205     }
1206   }
1207 
1208   return params;
1209 }
1210 
1211 // Bitwise-AND of a vector of bits to get the combined validity.
BuildCombinedValidity(const DexVector & validities)1212 llvm::Value* LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector& validities) {
1213   llvm::IRBuilder<>* builder = ir_builder();
1214   LLVMTypes* types = generator_->types();
1215 
1216   llvm::Value* isValid = types->true_constant();
1217   for (auto& dex : validities) {
1218     dex->Accept(*this);
1219     isValid = builder->CreateAnd(isValid, result()->data(), "validityBitAnd");
1220   }
1221   ADD_VISITOR_TRACE("combined validity is %T", isValid);
1222   return isValid;
1223 }
1224 
GetBufferReference(int idx,BufferType buffer_type,FieldPtr field)1225 llvm::Value* LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buffer_type,
1226                                                         FieldPtr field) {
1227   llvm::IRBuilder<>* builder = ir_builder();
1228 
1229   // Switch to the entry block to create a reference.
1230   llvm::BasicBlock* saved_block = builder->GetInsertBlock();
1231   builder->SetInsertPoint(entry_block_);
1232 
1233   llvm::Value* slot_ref = nullptr;
1234   switch (buffer_type) {
1235     case kBufferTypeValidity:
1236       slot_ref = generator_->GetValidityReference(arg_addrs_, idx, field);
1237       break;
1238 
1239     case kBufferTypeData:
1240       slot_ref = generator_->GetDataReference(arg_addrs_, idx, field);
1241       break;
1242 
1243     case kBufferTypeOffsets:
1244       slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field);
1245       break;
1246   }
1247 
1248   // Revert to the saved block.
1249   builder->SetInsertPoint(saved_block);
1250   return slot_ref;
1251 }
1252 
GetSliceOffset(int idx)1253 llvm::Value* LLVMGenerator::Visitor::GetSliceOffset(int idx) {
1254   return slice_offsets_[idx];
1255 }
1256 
GetLocalBitMapReference(int idx)1257 llvm::Value* LLVMGenerator::Visitor::GetLocalBitMapReference(int idx) {
1258   llvm::IRBuilder<>* builder = ir_builder();
1259 
1260   // Switch to the entry block to create a reference.
1261   llvm::BasicBlock* saved_block = builder->GetInsertBlock();
1262   builder->SetInsertPoint(entry_block_);
1263 
1264   llvm::Value* slot_ref = generator_->GetLocalBitMapReference(arg_local_bitmaps_, idx);
1265 
1266   // Revert to the saved block.
1267   builder->SetInsertPoint(saved_block);
1268   return slot_ref;
1269 }
1270 
1271 /// The local bitmap is pre-filled with 1s. Clear only if invalid.
ClearLocalBitMapIfNotValid(int local_bitmap_idx,llvm::Value * is_valid)1272 void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx,
1273                                                         llvm::Value* is_valid) {
1274   llvm::Value* slot_ref = GetLocalBitMapReference(local_bitmap_idx);
1275   generator_->ClearPackedBitValueIfFalse(slot_ref, loop_var_, is_valid);
1276 }
1277 
1278 // Hooks for tracing/printfs.
1279 //
1280 // replace %T with the type-specific format specifier.
1281 // For some reason, float/double literals are getting lost when printing with the generic
1282 // printf. so, use a wrapper instead.
ReplaceFormatInTrace(const std::string & in_msg,llvm::Value * value,std::string * print_fn)1283 std::string LLVMGenerator::ReplaceFormatInTrace(const std::string& in_msg,
1284                                                 llvm::Value* value,
1285                                                 std::string* print_fn) {
1286   std::string msg = in_msg;
1287   std::size_t pos = msg.find("%T");
1288   if (pos == std::string::npos) {
1289     DCHECK(0);
1290     return msg;
1291   }
1292 
1293   llvm::Type* type = value->getType();
1294   const char* fmt = "";
1295   if (type->isIntegerTy(1) || type->isIntegerTy(8) || type->isIntegerTy(16) ||
1296       type->isIntegerTy(32)) {
1297     fmt = "%d";
1298   } else if (type->isIntegerTy(64)) {
1299     // bigint
1300     fmt = "%lld";
1301   } else if (type->isFloatTy()) {
1302     // float
1303     fmt = "%f";
1304     *print_fn = "print_float";
1305   } else if (type->isDoubleTy()) {
1306     // float
1307     fmt = "%lf";
1308     *print_fn = "print_double";
1309   } else if (type->isPointerTy()) {
1310     // string
1311     fmt = "%s";
1312   } else {
1313     DCHECK(0);
1314   }
1315   msg.replace(pos, 2, fmt);
1316   return msg;
1317 }
1318 
AddTrace(const std::string & msg,llvm::Value * value)1319 void LLVMGenerator::AddTrace(const std::string& msg, llvm::Value* value) {
1320   if (!enable_ir_traces_) {
1321     return;
1322   }
1323 
1324   std::string dmsg = "IR_TRACE:: " + msg + "\n";
1325   std::string print_fn_name = "printf";
1326   if (value != nullptr) {
1327     dmsg = ReplaceFormatInTrace(dmsg, value, &print_fn_name);
1328   }
1329   trace_strings_.push_back(dmsg);
1330 
1331   // cast this to an llvm pointer.
1332   const char* str = trace_strings_.back().c_str();
1333   llvm::Constant* str_int_cast = types()->i64_constant((int64_t)str);
1334   llvm::Constant* str_ptr_cast =
1335       llvm::ConstantExpr::getIntToPtr(str_int_cast, types()->i8_ptr_type());
1336 
1337   std::vector<llvm::Value*> args;
1338   args.push_back(str_ptr_cast);
1339   if (value != nullptr) {
1340     args.push_back(value);
1341   }
1342   AddFunctionCall(print_fn_name, types()->i32_type(), args);
1343 }
1344 
1345 }  // namespace gandiva
1346