1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "gandiva/llvm_generator.h"
19
20 #include <fstream>
21 #include <iostream>
22 #include <sstream>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "gandiva/bitmap_accumulator.h"
28 #include "gandiva/decimal_ir.h"
29 #include "gandiva/dex.h"
30 #include "gandiva/expr_decomposer.h"
31 #include "gandiva/expression.h"
32 #include "gandiva/lvalue.h"
33
34 namespace gandiva {
35
// Emit an IR trace through AddTrace() only when IR tracing is enabled on this
// generator. Wrapped in do/while(0) so that `ADD_TRACE(...);` is exactly one
// statement, which keeps it safe inside an unbraced if/else.
#define ADD_TRACE(...)         \
  do {                         \
    if (enable_ir_traces_) {   \
      AddTrace(__VA_ARGS__);   \
    }                          \
  } while (0)
40
LLVMGenerator()41 LLVMGenerator::LLVMGenerator() : enable_ir_traces_(false) {}
42
Make(std::shared_ptr<Configuration> config,std::unique_ptr<LLVMGenerator> * llvm_generator)43 Status LLVMGenerator::Make(std::shared_ptr<Configuration> config,
44 std::unique_ptr<LLVMGenerator>* llvm_generator) {
45 std::unique_ptr<LLVMGenerator> llvmgen_obj(new LLVMGenerator());
46
47 ARROW_RETURN_NOT_OK(Engine::Make(config, &(llvmgen_obj->engine_)));
48 *llvm_generator = std::move(llvmgen_obj);
49
50 return Status::OK();
51 }
52
Add(const ExpressionPtr expr,const FieldDescriptorPtr output)53 Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr output) {
54 int idx = static_cast<int>(compiled_exprs_.size());
55 // decompose the expression to separate out value and validities.
56 ExprDecomposer decomposer(function_registry_, annotator_);
57 ValueValidityPairPtr value_validity;
58 ARROW_RETURN_NOT_OK(decomposer.Decompose(*expr->root(), &value_validity));
59 // Generate the IR function for the decomposed expression.
60 std::unique_ptr<CompiledExpr> compiled_expr(new CompiledExpr(value_validity, output));
61 llvm::Function* ir_function = nullptr;
62 ARROW_RETURN_NOT_OK(CodeGenExprValue(value_validity->value_expr(),
63 annotator_.buffer_count(), output, idx,
64 &ir_function, selection_vector_mode_));
65 compiled_expr->SetIRFunction(selection_vector_mode_, ir_function);
66
67 compiled_exprs_.push_back(std::move(compiled_expr));
68 return Status::OK();
69 }
70
71 /// Build and optimise module for projection expression.
Build(const ExpressionVector & exprs,SelectionVector::Mode mode)72 Status LLVMGenerator::Build(const ExpressionVector& exprs, SelectionVector::Mode mode) {
73 selection_vector_mode_ = mode;
74 for (auto& expr : exprs) {
75 auto output = annotator_.AddOutputFieldDescriptor(expr->result());
76 ARROW_RETURN_NOT_OK(Add(expr, output));
77 }
78
79 // Compile and inject into the process' memory the generated function.
80 ARROW_RETURN_NOT_OK(engine_->FinalizeModule());
81
82 // setup the jit functions for each expression.
83 for (auto& compiled_expr : compiled_exprs_) {
84 auto ir_fn = compiled_expr->GetIRFunction(mode);
85 auto jit_fn = reinterpret_cast<EvalFunc>(engine_->CompiledFunction(ir_fn));
86 compiled_expr->SetJITFunction(selection_vector_mode_, jit_fn);
87 }
88
89 return Status::OK();
90 }
91
92 /// Execute the compiled module against the provided vectors.
Execute(const arrow::RecordBatch & record_batch,const ArrayDataVector & output_vector)93 Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
94 const ArrayDataVector& output_vector) {
95 return Execute(record_batch, nullptr, output_vector);
96 }
97
98 /// Execute the compiled module against the provided vectors based on the type of
99 /// selection vector.
Execute(const arrow::RecordBatch & record_batch,const SelectionVector * selection_vector,const ArrayDataVector & output_vector)100 Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
101 const SelectionVector* selection_vector,
102 const ArrayDataVector& output_vector) {
103 DCHECK_GT(record_batch.num_rows(), 0);
104
105 auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector);
106 DCHECK_GT(eval_batch->GetNumBuffers(), 0);
107
108 auto mode = SelectionVector::MODE_NONE;
109 if (selection_vector != nullptr) {
110 mode = selection_vector->GetMode();
111 }
112 if (mode != selection_vector_mode_) {
113 return Status::Invalid("llvm expression built for selection vector mode ",
114 selection_vector_mode_, " received vector with mode ", mode);
115 }
116
117 for (auto& compiled_expr : compiled_exprs_) {
118 // generate data/offset vectors.
119 const uint8_t* selection_buffer = nullptr;
120 auto num_output_rows = record_batch.num_rows();
121 if (selection_vector != nullptr) {
122 selection_buffer = selection_vector->GetBuffer().data();
123 num_output_rows = selection_vector->GetNumSlots();
124 }
125
126 EvalFunc jit_function = compiled_expr->GetJITFunction(mode);
127 jit_function(eval_batch->GetBufferArray(), eval_batch->GetBufferOffsetArray(),
128 eval_batch->GetLocalBitMapArray(), selection_buffer,
129 (int64_t)eval_batch->GetExecutionContext(), num_output_rows);
130
131 // check for execution errors
132 ARROW_RETURN_IF(
133 eval_batch->GetExecutionContext()->has_error(),
134 Status::ExecutionError(eval_batch->GetExecutionContext()->get_error()));
135
136 // generate validity vectors.
137 ComputeBitMapsForExpr(*compiled_expr, *eval_batch, selection_vector);
138 }
139
140 return Status::OK();
141 }
142
LoadVectorAtIndex(llvm::Value * arg_addrs,int idx,const std::string & name)143 llvm::Value* LLVMGenerator::LoadVectorAtIndex(llvm::Value* arg_addrs, int idx,
144 const std::string& name) {
145 auto* idx_val = types()->i32_constant(idx);
146 auto* offset = ir_builder()->CreateGEP(arg_addrs, idx_val, name + "_mem_addr");
147 return ir_builder()->CreateLoad(offset, name + "_mem");
148 }
149
150 /// Get reference to validity array at specified index in the args list.
GetValidityReference(llvm::Value * arg_addrs,int idx,FieldPtr field)151 llvm::Value* LLVMGenerator::GetValidityReference(llvm::Value* arg_addrs, int idx,
152 FieldPtr field) {
153 const std::string& name = field->name();
154 llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
155 return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), name + "_varray");
156 }
157
158 /// Get reference to data array at specified index in the args list.
GetDataBufferPtrReference(llvm::Value * arg_addrs,int idx,FieldPtr field)159 llvm::Value* LLVMGenerator::GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx,
160 FieldPtr field) {
161 const std::string& name = field->name();
162 llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
163 return ir_builder()->CreateIntToPtr(load, types()->i8_ptr_type(), name + "_buf_ptr");
164 }
165
166 /// Get reference to data array at specified index in the args list.
GetDataReference(llvm::Value * arg_addrs,int idx,FieldPtr field)167 llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx,
168 FieldPtr field) {
169 const std::string& name = field->name();
170 llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
171 llvm::Type* base_type = types()->DataVecType(field->type());
172 llvm::Value* ret;
173 if (base_type->isPointerTy()) {
174 ret = ir_builder()->CreateIntToPtr(load, base_type, name + "_darray");
175 } else {
176 llvm::Type* pointer_type = types()->ptr_type(base_type);
177 ret = ir_builder()->CreateIntToPtr(load, pointer_type, name + "_darray");
178 }
179 return ret;
180 }
181
182 /// Get reference to offsets array at specified index in the args list.
GetOffsetsReference(llvm::Value * arg_addrs,int idx,FieldPtr field)183 llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx,
184 FieldPtr field) {
185 const std::string& name = field->name();
186 llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
187 return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray");
188 }
189
190 /// Get reference to local bitmap array at specified index in the args list.
GetLocalBitMapReference(llvm::Value * arg_bitmaps,int idx)191 llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) {
192 llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, idx, "");
193 return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(),
194 std::to_string(idx) + "_lbmap");
195 }
196
197 /// \brief Generate code for one expression.
198
199 // Sample IR code for "c1:int + c2:int"
200 //
201 // The C-code equivalent is :
202 // ------------------------------
203 // int expr_0(int64_t *addrs, int64_t *local_bitmaps,
204 // int64_t execution_context_ptr, int64_t nrecords) {
205 // int *outVec = (int *) addrs[5];
206 // int *c0Vec = (int *) addrs[1];
207 // int *c1Vec = (int *) addrs[3];
208 // for (int loop_var = 0; loop_var < nrecords; ++loop_var) {
209 // int c0 = c0Vec[loop_var];
210 // int c1 = c1Vec[loop_var];
211 // int out = c0 + c1;
212 // outVec[loop_var] = out;
213 // }
214 // }
215 //
216 // IR Code
217 // --------
218 //
219 // define i32 @expr_0(i64* %args, i64* %local_bitmaps, i64 %execution_context_ptr, , i64
220 // %nrecords) { entry:
221 // %outmemAddr = getelementptr i64, i64* %args, i32 5
222 // %outmem = load i64, i64* %outmemAddr
223 // %outVec = inttoptr i64 %outmem to i32*
224 // %c0memAddr = getelementptr i64, i64* %args, i32 1
225 // %c0mem = load i64, i64* %c0memAddr
226 // %c0Vec = inttoptr i64 %c0mem to i32*
227 // %c1memAddr = getelementptr i64, i64* %args, i32 3
228 // %c1mem = load i64, i64* %c1memAddr
229 // %c1Vec = inttoptr i64 %c1mem to i32*
230 // br label %loop
231 // loop: ; preds = %loop, %entry
232 // %loop_var = phi i64 [ 0, %entry ], [ %"loop_var+1", %loop ]
233 // %"loop_var+1" = add i64 %loop_var, 1
234 // %0 = getelementptr i32, i32* %c0Vec, i32 %loop_var
235 // %c0 = load i32, i32* %0
236 // %1 = getelementptr i32, i32* %c1Vec, i32 %loop_var
237 // %c1 = load i32, i32* %1
238 // %add_int_int = call i32 @add_int_int(i32 %c0, i32 %c1)
239 // %2 = getelementptr i32, i32* %outVec, i32 %loop_var
240 // store i32 %add_int_int, i32* %2
241 // %"loop_var < nrec" = icmp slt i64 %"loop_var+1", %nrecords
242 // br i1 %"loop_var < nrec", label %loop, label %exit
243 // exit: ; preds = %loop
244 // ret i32 0
245 // }
CodeGenExprValue(DexPtr value_expr,int buffer_count,FieldDescriptorPtr output,int suffix_idx,llvm::Function ** fn,SelectionVector::Mode selection_vector_mode)246 Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
247 FieldDescriptorPtr output, int suffix_idx,
248 llvm::Function** fn,
249 SelectionVector::Mode selection_vector_mode) {
250 llvm::IRBuilder<>* builder = ir_builder();
251 // Create fn prototype :
252 // int expr_1 (long **addrs, long *offsets, long **bitmaps,
253 // long *context_ptr, long nrec)
254 std::vector<llvm::Type*> arguments;
255 arguments.push_back(types()->i64_ptr_type()); // addrs
256 arguments.push_back(types()->i64_ptr_type()); // offsets
257 arguments.push_back(types()->i64_ptr_type()); // bitmaps
258 switch (selection_vector_mode) {
259 case SelectionVector::MODE_NONE:
260 case SelectionVector::MODE_UINT16:
261 arguments.push_back(types()->ptr_type(types()->i16_type()));
262 break;
263 case SelectionVector::MODE_UINT32:
264 arguments.push_back(types()->i32_ptr_type());
265 break;
266 case SelectionVector::MODE_UINT64:
267 arguments.push_back(types()->i64_ptr_type());
268 }
269 arguments.push_back(types()->i64_type()); // ctx_ptr
270 arguments.push_back(types()->i64_type()); // nrec
271 llvm::FunctionType* prototype =
272 llvm::FunctionType::get(types()->i32_type(), arguments, false /*isVarArg*/);
273
274 // Create fn
275 std::string func_name = "expr_" + std::to_string(suffix_idx) + "_" +
276 std::to_string(static_cast<int>(selection_vector_mode));
277 engine_->AddFunctionToCompile(func_name);
278 *fn = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, func_name,
279 module());
280 ARROW_RETURN_IF((*fn == nullptr), Status::CodeGenError("Error creating function."));
281
282 // Name the arguments
283 llvm::Function::arg_iterator args = (*fn)->arg_begin();
284 llvm::Value* arg_addrs = &*args;
285 arg_addrs->setName("inputs_addr");
286 ++args;
287 llvm::Value* arg_addr_offsets = &*args;
288 arg_addr_offsets->setName("inputs_addr_offsets");
289 ++args;
290 llvm::Value* arg_local_bitmaps = &*args;
291 arg_local_bitmaps->setName("local_bitmaps");
292 ++args;
293 llvm::Value* arg_selection_vector = &*args;
294 arg_selection_vector->setName("selection_vector");
295 ++args;
296 llvm::Value* arg_context_ptr = &*args;
297 arg_context_ptr->setName("context_ptr");
298 ++args;
299 llvm::Value* arg_nrecords = &*args;
300 arg_nrecords->setName("nrecords");
301
302 llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context(), "entry", *fn);
303 llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context(), "loop", *fn);
304 llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context(), "exit", *fn);
305
306 // Add reference to output vector (in entry block)
307 builder->SetInsertPoint(loop_entry);
308 llvm::Value* output_ref =
309 GetDataReference(arg_addrs, output->data_idx(), output->field());
310 llvm::Value* output_buffer_ptr_ref = GetDataBufferPtrReference(
311 arg_addrs, output->data_buffer_ptr_idx(), output->field());
312 llvm::Value* output_offset_ref =
313 GetOffsetsReference(arg_addrs, output->offsets_idx(), output->field());
314
315 std::vector<llvm::Value*> slice_offsets;
316 for (int idx = 0; idx < buffer_count; idx++) {
317 auto offsetAddr = builder->CreateGEP(arg_addr_offsets, types()->i32_constant(idx));
318 auto offset = builder->CreateLoad(offsetAddr);
319 slice_offsets.push_back(offset);
320 }
321
322 // Loop body
323 builder->SetInsertPoint(loop_body);
324
325 // define loop_var : start with 0, +1 after each iter
326 llvm::PHINode* loop_var = builder->CreatePHI(types()->i64_type(), 2, "loop_var");
327
328 llvm::Value* position_var = loop_var;
329 if (selection_vector_mode != SelectionVector::MODE_NONE) {
330 position_var = builder->CreateIntCast(
331 builder->CreateLoad(builder->CreateGEP(arg_selection_vector, loop_var),
332 "uncasted_position_var"),
333 types()->i64_type(), true, "position_var");
334 }
335
336 // The visitor can add code to both the entry/loop blocks.
337 Visitor visitor(this, *fn, loop_entry, arg_addrs, arg_local_bitmaps, slice_offsets,
338 arg_context_ptr, position_var);
339 value_expr->Accept(visitor);
340 LValuePtr output_value = visitor.result();
341
342 // The "current" block may have changed due to code generation in the visitor.
343 llvm::BasicBlock* loop_body_tail = builder->GetInsertBlock();
344
345 // add jump to "loop block" at the end of the "setup block".
346 builder->SetInsertPoint(loop_entry);
347 builder->CreateBr(loop_body);
348
349 // save the value in the output vector.
350 builder->SetInsertPoint(loop_body_tail);
351
352 auto output_type_id = output->Type()->id();
353 if (output_type_id == arrow::Type::BOOL) {
354 SetPackedBitValue(output_ref, loop_var, output_value->data());
355 } else if (arrow::is_primitive(output_type_id) ||
356 output_type_id == arrow::Type::DECIMAL) {
357 llvm::Value* slot_offset = builder->CreateGEP(output_ref, loop_var);
358 builder->CreateStore(output_value->data(), slot_offset);
359 } else if (arrow::is_binary_like(output_type_id)) {
360 // Var-len output. Make a function call to populate the data.
361 // if there is an error, the fn sets it in the context. And, will be returned at the
362 // end of this row batch.
363 AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(),
364 {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var,
365 output_value->data(), output_value->length()});
366 } else {
367 return Status::NotImplemented("output type ", output->Type()->ToString(),
368 " not supported");
369 }
370 ADD_TRACE("saving result " + output->Name() + " value %T", output_value->data());
371
372 if (visitor.has_arena_allocs()) {
373 // Reset allocations to avoid excessive memory usage. Once the result is copied to
374 // the output vector (store instruction above), any memory allocations in this
375 // iteration of the loop are no longer needed.
376 std::vector<llvm::Value*> reset_args;
377 reset_args.push_back(arg_context_ptr);
378 AddFunctionCall("gdv_fn_context_arena_reset", types()->void_type(), reset_args);
379 }
380
381 // check loop_var
382 loop_var->addIncoming(types()->i64_constant(0), loop_entry);
383 llvm::Value* loop_update =
384 builder->CreateAdd(loop_var, types()->i64_constant(1), "loop_var+1");
385 loop_var->addIncoming(loop_update, loop_body_tail);
386
387 llvm::Value* loop_var_check =
388 builder->CreateICmpSLT(loop_update, arg_nrecords, "loop_var < nrec");
389 builder->CreateCondBr(loop_var_check, loop_body, loop_exit);
390
391 // Loop exit
392 builder->SetInsertPoint(loop_exit);
393 builder->CreateRet(types()->i32_constant(0));
394 return Status::OK();
395 }
396
397 /// Return value of a bit in bitMap.
GetPackedBitValue(llvm::Value * bitmap,llvm::Value * position)398 llvm::Value* LLVMGenerator::GetPackedBitValue(llvm::Value* bitmap,
399 llvm::Value* position) {
400 ADD_TRACE("fetch bit at position %T", position);
401
402 llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
403 bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
404 return AddFunctionCall("bitMapGetBit", types()->i1_type(), {bitmap8, position});
405 }
406
407 /// Set the value of a bit in bitMap.
SetPackedBitValue(llvm::Value * bitmap,llvm::Value * position,llvm::Value * value)408 void LLVMGenerator::SetPackedBitValue(llvm::Value* bitmap, llvm::Value* position,
409 llvm::Value* value) {
410 ADD_TRACE("set bit at position %T", position);
411 ADD_TRACE(" to value %T ", value);
412
413 llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
414 bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
415 AddFunctionCall("bitMapSetBit", types()->void_type(), {bitmap8, position, value});
416 }
417
418 /// Return value of a bit in validity bitMap (handles null bitmaps too).
GetPackedValidityBitValue(llvm::Value * bitmap,llvm::Value * position)419 llvm::Value* LLVMGenerator::GetPackedValidityBitValue(llvm::Value* bitmap,
420 llvm::Value* position) {
421 ADD_TRACE("fetch validity bit at position %T", position);
422
423 llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
424 bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
425 return AddFunctionCall("bitMapValidityGetBit", types()->i1_type(), {bitmap8, position});
426 }
427
428 /// Clear the bit in bitMap if value = false.
ClearPackedBitValueIfFalse(llvm::Value * bitmap,llvm::Value * position,llvm::Value * value)429 void LLVMGenerator::ClearPackedBitValueIfFalse(llvm::Value* bitmap, llvm::Value* position,
430 llvm::Value* value) {
431 ADD_TRACE("ClearIfFalse bit at position %T", position);
432 ADD_TRACE(" value %T ", value);
433
434 llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
435 bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
436 AddFunctionCall("bitMapClearBitIfFalse", types()->void_type(),
437 {bitmap8, position, value});
438 }
439
440 /// Extract the bitmap addresses, and do an intersection.
ComputeBitMapsForExpr(const CompiledExpr & compiled_expr,const EvalBatch & eval_batch,const SelectionVector * selection_vector)441 void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr,
442 const EvalBatch& eval_batch,
443 const SelectionVector* selection_vector) {
444 auto validities = compiled_expr.value_validity()->validity_exprs();
445
446 // Extract all the source bitmap addresses.
447 BitMapAccumulator accumulator(eval_batch);
448 for (auto& validity_dex : validities) {
449 validity_dex->Accept(accumulator);
450 }
451
452 // Extract the destination bitmap address.
453 int out_idx = compiled_expr.output()->validity_idx();
454 uint8_t* dst_bitmap = eval_batch.GetBuffer(out_idx);
455 // Compute the destination bitmap.
456 if (selection_vector == nullptr) {
457 accumulator.ComputeResult(dst_bitmap);
458 } else {
459 /// The output bitmap is an intersection of some input/local bitmaps. However, with a
460 /// selection vector, only the bits corresponding to the indices in the selection
461 /// vector need to set in the output bitmap. This is done in two steps :
462 ///
463 /// 1. Do the intersection of input/local bitmaps to generate a temporary bitmap.
464 /// 2. copy just the relevant bits from the temporary bitmap to the output bitmap.
465 LocalBitMapsHolder bit_map_holder(eval_batch.num_records(), 1);
466 uint8_t* temp_bitmap = bit_map_holder.GetLocalBitMap(0);
467 accumulator.ComputeResult(temp_bitmap);
468
469 auto num_out_records = selection_vector->GetNumSlots();
470 // the memset isn't required, doing it just for valgrind.
471 memset(dst_bitmap, 0, arrow::BitUtil::BytesForBits(num_out_records));
472 for (auto i = 0; i < num_out_records; ++i) {
473 auto bit = arrow::BitUtil::GetBit(temp_bitmap, selection_vector->GetIndex(i));
474 arrow::BitUtil::SetBitTo(dst_bitmap, i, bit);
475 }
476 }
477 }
478
AddFunctionCall(const std::string & full_name,llvm::Type * ret_type,const std::vector<llvm::Value * > & args)479 llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name,
480 llvm::Type* ret_type,
481 const std::vector<llvm::Value*>& args) {
482 // find the llvm function.
483 llvm::Function* fn = module()->getFunction(full_name);
484 DCHECK_NE(fn, nullptr) << "missing function " << full_name;
485
486 if (enable_ir_traces_ && !full_name.compare("printf") &&
487 !full_name.compare("printff")) {
488 // Trace for debugging
489 ADD_TRACE("invoke native fn " + full_name);
490 }
491
492 // build a call to the llvm function.
493 llvm::Value* value;
494 if (ret_type->isVoidTy()) {
495 // void functions can't have a name for the call.
496 value = ir_builder()->CreateCall(fn, args);
497 } else {
498 value = ir_builder()->CreateCall(fn, args, full_name);
499 DCHECK(value->getType() == ret_type);
500 }
501
502 return value;
503 }
504
BuildDecimalLValue(llvm::Value * value,DataTypePtr arrow_type)505 std::shared_ptr<DecimalLValue> LLVMGenerator::BuildDecimalLValue(llvm::Value* value,
506 DataTypePtr arrow_type) {
507 // only decimals of size 128-bit supported.
508 DCHECK(is_decimal_128(arrow_type));
509 auto decimal_type =
510 arrow::internal::checked_cast<arrow::DecimalType*>(arrow_type.get());
511 return std::make_shared<DecimalLValue>(value, nullptr,
512 types()->i32_constant(decimal_type->precision()),
513 types()->i32_constant(decimal_type->scale()));
514 }
515
// Emit an IR trace through the owning generator, only when IR tracing is enabled on
// it. Wrapped in do/while(0) so that `ADD_VISITOR_TRACE(...);` is exactly one
// statement, which keeps it safe inside an unbraced if/else.
#define ADD_VISITOR_TRACE(...)            \
  do {                                    \
    if (generator_->enable_ir_traces_) {  \
      generator_->AddTrace(__VA_ARGS__);  \
    }                                     \
  } while (0)
520
521 // Visitor for generating the code for a decomposed expression.
Visitor(LLVMGenerator * generator,llvm::Function * function,llvm::BasicBlock * entry_block,llvm::Value * arg_addrs,llvm::Value * arg_local_bitmaps,std::vector<llvm::Value * > slice_offsets,llvm::Value * arg_context_ptr,llvm::Value * loop_var)522 LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* function,
523 llvm::BasicBlock* entry_block, llvm::Value* arg_addrs,
524 llvm::Value* arg_local_bitmaps,
525 std::vector<llvm::Value*> slice_offsets,
526 llvm::Value* arg_context_ptr, llvm::Value* loop_var)
527 : generator_(generator),
528 function_(function),
529 entry_block_(entry_block),
530 arg_addrs_(arg_addrs),
531 arg_local_bitmaps_(arg_local_bitmaps),
532 slice_offsets_(slice_offsets),
533 arg_context_ptr_(arg_context_ptr),
534 loop_var_(loop_var),
535 has_arena_allocs_(false) {
536 ADD_VISITOR_TRACE("Iteration %T", loop_var);
537 }
538
Visit(const VectorReadFixedLenValueDex & dex)539 void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) {
540 llvm::IRBuilder<>* builder = ir_builder();
541 llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
542 llvm::Value* slot_index = builder->CreateAdd(loop_var_, GetSliceOffset(dex.DataIdx()));
543 llvm::Value* slot_value;
544 std::shared_ptr<LValue> lvalue;
545
546 switch (dex.FieldType()->id()) {
547 case arrow::Type::BOOL:
548 slot_value = generator_->GetPackedBitValue(slot_ref, slot_index);
549 lvalue = std::make_shared<LValue>(slot_value);
550 break;
551
552 case arrow::Type::DECIMAL: {
553 auto slot_offset = builder->CreateGEP(slot_ref, slot_index);
554 slot_value = builder->CreateLoad(slot_offset, dex.FieldName());
555 lvalue = generator_->BuildDecimalLValue(slot_value, dex.FieldType());
556 break;
557 }
558
559 default: {
560 auto slot_offset = builder->CreateGEP(slot_ref, slot_index);
561 slot_value = builder->CreateLoad(slot_offset, dex.FieldName());
562 lvalue = std::make_shared<LValue>(slot_value);
563 break;
564 }
565 }
566 ADD_VISITOR_TRACE("visit fixed-len data vector " + dex.FieldName() + " value %T",
567 slot_value);
568 result_ = lvalue;
569 }
570
Visit(const VectorReadVarLenValueDex & dex)571 void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) {
572 llvm::IRBuilder<>* builder = ir_builder();
573 llvm::Value* slot;
574
575 // compute len from the offsets array.
576 llvm::Value* offsets_slot_ref =
577 GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field());
578 llvm::Value* offsets_slot_index =
579 builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx()));
580
581 // => offset_start = offsets[loop_var]
582 slot = builder->CreateGEP(offsets_slot_ref, offsets_slot_index);
583 llvm::Value* offset_start = builder->CreateLoad(slot, "offset_start");
584
585 // => offset_end = offsets[loop_var + 1]
586 llvm::Value* offsets_slot_index_next = builder->CreateAdd(
587 offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1");
588 slot = builder->CreateGEP(offsets_slot_ref, offsets_slot_index_next);
589 llvm::Value* offset_end = builder->CreateLoad(slot, "offset_end");
590
591 // => len_value = offset_end - offset_start
592 llvm::Value* len_value =
593 builder->CreateSub(offset_end, offset_start, dex.FieldName() + "Len");
594
595 // get the data from the data array, at offset 'offset_start'.
596 llvm::Value* data_slot_ref =
597 GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
598 llvm::Value* data_value = builder->CreateGEP(data_slot_ref, offset_start);
599 ADD_VISITOR_TRACE("visit var-len data vector " + dex.FieldName() + " len %T",
600 len_value);
601 result_.reset(new LValue(data_value, len_value));
602 }
603
Visit(const VectorReadValidityDex & dex)604 void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) {
605 llvm::IRBuilder<>* builder = ir_builder();
606 llvm::Value* slot_ref =
607 GetBufferReference(dex.ValidityIdx(), kBufferTypeValidity, dex.Field());
608 llvm::Value* slot_index =
609 builder->CreateAdd(loop_var_, GetSliceOffset(dex.ValidityIdx()));
610 llvm::Value* validity = generator_->GetPackedValidityBitValue(slot_ref, slot_index);
611
612 ADD_VISITOR_TRACE("visit validity vector " + dex.FieldName() + " value %T", validity);
613 result_.reset(new LValue(validity));
614 }
615
Visit(const LocalBitMapValidityDex & dex)616 void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) {
617 llvm::Value* slot_ref = GetLocalBitMapReference(dex.local_bitmap_idx());
618 llvm::Value* validity = generator_->GetPackedBitValue(slot_ref, loop_var_);
619
620 ADD_VISITOR_TRACE(
621 "visit local bitmap " + std::to_string(dex.local_bitmap_idx()) + " value %T",
622 validity);
623 result_.reset(new LValue(validity));
624 }
625
Visit(const TrueDex & dex)626 void LLVMGenerator::Visitor::Visit(const TrueDex& dex) {
627 result_.reset(new LValue(generator_->types()->true_constant()));
628 }
629
Visit(const FalseDex & dex)630 void LLVMGenerator::Visitor::Visit(const FalseDex& dex) {
631 result_.reset(new LValue(generator_->types()->false_constant()));
632 }
633
Visit(const LiteralDex & dex)634 void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) {
635 LLVMTypes* types = generator_->types();
636 llvm::Value* value = nullptr;
637 llvm::Value* len = nullptr;
638
639 switch (dex.type()->id()) {
640 case arrow::Type::BOOL:
641 value = types->i1_constant(arrow::util::get<bool>(dex.holder()));
642 break;
643
644 case arrow::Type::UINT8:
645 value = types->i8_constant(arrow::util::get<uint8_t>(dex.holder()));
646 break;
647
648 case arrow::Type::UINT16:
649 value = types->i16_constant(arrow::util::get<uint16_t>(dex.holder()));
650 break;
651
652 case arrow::Type::UINT32:
653 value = types->i32_constant(arrow::util::get<uint32_t>(dex.holder()));
654 break;
655
656 case arrow::Type::UINT64:
657 value = types->i64_constant(arrow::util::get<uint64_t>(dex.holder()));
658 break;
659
660 case arrow::Type::INT8:
661 value = types->i8_constant(arrow::util::get<int8_t>(dex.holder()));
662 break;
663
664 case arrow::Type::INT16:
665 value = types->i16_constant(arrow::util::get<int16_t>(dex.holder()));
666 break;
667
668 case arrow::Type::FLOAT:
669 value = types->float_constant(arrow::util::get<float>(dex.holder()));
670 break;
671
672 case arrow::Type::DOUBLE:
673 value = types->double_constant(arrow::util::get<double>(dex.holder()));
674 break;
675
676 case arrow::Type::STRING:
677 case arrow::Type::BINARY: {
678 const std::string& str = arrow::util::get<std::string>(dex.holder());
679
680 llvm::Constant* str_int_cast = types->i64_constant((int64_t)str.c_str());
681 value = llvm::ConstantExpr::getIntToPtr(str_int_cast, types->i8_ptr_type());
682 len = types->i32_constant(static_cast<int32_t>(str.length()));
683 break;
684 }
685
686 case arrow::Type::INT32:
687 case arrow::Type::DATE32:
688 case arrow::Type::TIME32:
689 case arrow::Type::INTERVAL_MONTHS:
690 value = types->i32_constant(arrow::util::get<int32_t>(dex.holder()));
691 break;
692
693 case arrow::Type::INT64:
694 case arrow::Type::DATE64:
695 case arrow::Type::TIME64:
696 case arrow::Type::TIMESTAMP:
697 case arrow::Type::INTERVAL_DAY_TIME:
698 value = types->i64_constant(arrow::util::get<int64_t>(dex.holder()));
699 break;
700
701 case arrow::Type::DECIMAL: {
702 // build code for struct
703 auto scalar = arrow::util::get<DecimalScalar128>(dex.holder());
704 // ConstantInt doesn't have a get method that takes int128 or a pair of int64. so,
705 // passing the string representation instead.
706 auto int128_value =
707 llvm::ConstantInt::get(llvm::Type::getInt128Ty(*generator_->context()),
708 Decimal128(scalar.value()).ToIntegerString(), 10);
709 auto type = arrow::decimal(scalar.precision(), scalar.scale());
710 auto lvalue = generator_->BuildDecimalLValue(int128_value, type);
711 // set it as the l-value and return.
712 result_ = lvalue;
713 return;
714 }
715
716 default:
717 DCHECK(0);
718 }
719 ADD_VISITOR_TRACE("visit Literal %T", value);
720 result_.reset(new LValue(value, len));
721 }
722
Visit(const NonNullableFuncDex & dex)723 void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
724 const std::string& function_name = dex.func_descriptor()->name();
725 ADD_VISITOR_TRACE("visit NonNullableFunc base function " + function_name);
726
727 const NativeFunction* native_function = dex.native_function();
728
729 // build the function params (ignore validity).
730 auto params = BuildParams(dex.function_holder().get(), dex.args(), false,
731 native_function->NeedsContext());
732
733 auto arrow_return_type = dex.func_descriptor()->return_type();
734 if (native_function->CanReturnErrors()) {
735 // slow path : if a function can return errors, skip invoking the function
736 // unless all of the input args are valid. Otherwise, it can cause spurious errors.
737
738 llvm::IRBuilder<>* builder = ir_builder();
739 LLVMTypes* types = generator_->types();
740 auto arrow_type_id = arrow_return_type->id();
741 auto result_type = types->IRType(arrow_type_id);
742
743 // Build combined validity of the args.
744 llvm::Value* is_valid = types->true_constant();
745 for (auto& pair : dex.args()) {
746 auto arg_validity = BuildCombinedValidity(pair->validity_exprs());
747 is_valid = builder->CreateAnd(is_valid, arg_validity, "validityBitAnd");
748 }
749
750 // then block
751 auto then_lambda = [&] {
752 ADD_VISITOR_TRACE("fn " + function_name +
753 " can return errors : all args valid, invoke fn");
754 return BuildFunctionCall(native_function, arrow_return_type, ¶ms);
755 };
756
757 // else block
758 auto else_lambda = [&] {
759 ADD_VISITOR_TRACE("fn " + function_name +
760 " can return errors : not all args valid, return dummy value");
761 llvm::Value* else_value = types->NullConstant(result_type);
762 llvm::Value* else_value_len = nullptr;
763 if (arrow::is_binary_like(arrow_type_id)) {
764 else_value_len = types->i32_constant(0);
765 }
766 return std::make_shared<LValue>(else_value, else_value_len);
767 };
768
769 result_ = BuildIfElse(is_valid, then_lambda, else_lambda, arrow_return_type);
770 } else {
771 // fast path : invoke function without computing validities.
772 result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms);
773 }
774 }
775
Visit(const NullableNeverFuncDex & dex)776 void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex& dex) {
777 ADD_VISITOR_TRACE("visit NullableNever base function " + dex.func_descriptor()->name());
778 const NativeFunction* native_function = dex.native_function();
779
780 // build function params along with validity.
781 auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
782 native_function->NeedsContext());
783
784 auto arrow_return_type = dex.func_descriptor()->return_type();
785 result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms);
786 }
787
Visit(const NullableInternalFuncDex & dex)788 void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) {
789 ADD_VISITOR_TRACE("visit NullableInternal base function " +
790 dex.func_descriptor()->name());
791 llvm::IRBuilder<>* builder = ir_builder();
792 LLVMTypes* types = generator_->types();
793
794 const NativeFunction* native_function = dex.native_function();
795
796 // build function params along with validity.
797 auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
798 native_function->NeedsContext());
799
800 // add an extra arg for validity (allocated on stack).
801 llvm::AllocaInst* result_valid_ptr =
802 new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_);
803 params.push_back(result_valid_ptr);
804
805 auto arrow_return_type = dex.func_descriptor()->return_type();
806 result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms);
807
808 // load the result validity and truncate to i1.
809 llvm::Value* result_valid_i8 = builder->CreateLoad(result_valid_ptr);
810 llvm::Value* result_valid = builder->CreateTrunc(result_valid_i8, types->i1_type());
811
812 // set validity bit in the local bitmap.
813 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), result_valid);
814 }
815
Visit(const IfDex & dex)816 void LLVMGenerator::Visitor::Visit(const IfDex& dex) {
817 ADD_VISITOR_TRACE("visit IfExpression");
818 llvm::IRBuilder<>* builder = ir_builder();
819
820 // Evaluate condition.
821 LValuePtr if_condition = BuildValueAndValidity(dex.condition_vv());
822
823 // Check if the result is valid, and there is match.
824 llvm::Value* validAndMatched =
825 builder->CreateAnd(if_condition->data(), if_condition->validity(), "validAndMatch");
826
827 // then block
828 auto then_lambda = [&] {
829 ADD_VISITOR_TRACE("branch to then block");
830 LValuePtr then_lvalue = BuildValueAndValidity(dex.then_vv());
831 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), then_lvalue->validity());
832 ADD_VISITOR_TRACE("IfExpression result validity %T in matching then",
833 then_lvalue->validity());
834 return then_lvalue;
835 };
836
837 // else block
838 auto else_lambda = [&] {
839 LValuePtr else_lvalue;
840 if (dex.is_terminal_else()) {
841 ADD_VISITOR_TRACE("branch to terminal else block");
842
843 else_lvalue = BuildValueAndValidity(dex.else_vv());
844 // update the local bitmap with the validity.
845 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), else_lvalue->validity());
846 ADD_VISITOR_TRACE("IfExpression result validity %T in terminal else",
847 else_lvalue->validity());
848 } else {
849 ADD_VISITOR_TRACE("branch to non-terminal else block");
850
851 // this is a non-terminal else. let the child (nested if/else) handle validity.
852 auto value_expr = dex.else_vv().value_expr();
853 value_expr->Accept(*this);
854 else_lvalue = result();
855 }
856 return else_lvalue;
857 };
858
859 // build the if-else condition.
860 result_ = BuildIfElse(validAndMatched, then_lambda, else_lambda, dex.result_type());
861 if (arrow::is_binary_like(dex.result_type()->id())) {
862 ADD_VISITOR_TRACE("IfElse result length %T", result_->length());
863 }
864 ADD_VISITOR_TRACE("IfElse result value %T", result_->data());
865 }
866
867 // Boolean AND
868 // if any arg is valid and false,
869 // short-circuit and return FALSE (value=false, valid=true)
870 // else if all args are valid and true
871 // return TRUE (value=true, valid=true)
872 // else
873 // return NULL (value=true, valid=false)
874
Visit(const BooleanAndDex & dex)875 void LLVMGenerator::Visitor::Visit(const BooleanAndDex& dex) {
876 ADD_VISITOR_TRACE("visit BooleanAndExpression");
877 llvm::IRBuilder<>* builder = ir_builder();
878 LLVMTypes* types = generator_->types();
879 llvm::LLVMContext* context = generator_->context();
880
881 // Create blocks for short-circuit.
882 llvm::BasicBlock* short_circuit_bb =
883 llvm::BasicBlock::Create(*context, "short_circuit", function_);
884 llvm::BasicBlock* non_short_circuit_bb =
885 llvm::BasicBlock::Create(*context, "non_short_circuit", function_);
886 llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
887
888 llvm::Value* all_exprs_valid = types->true_constant();
889 for (auto& pair : dex.args()) {
890 LValuePtr current = BuildValueAndValidity(*pair);
891
892 ADD_VISITOR_TRACE("BooleanAndExpression arg value %T", current->data());
893 ADD_VISITOR_TRACE("BooleanAndExpression arg validity %T", current->validity());
894
895 // short-circuit if valid and false
896 llvm::Value* is_false = builder->CreateNot(current->data());
897 llvm::Value* valid_and_false =
898 builder->CreateAnd(is_false, current->validity(), "valid_and_false");
899
900 llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
901 builder->CreateCondBr(valid_and_false, short_circuit_bb, else_bb);
902
903 // Emit the else block.
904 builder->SetInsertPoint(else_bb);
905 // remember if any nulls were encountered.
906 all_exprs_valid =
907 builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd");
908 // continue to evaluate the next pair in list.
909 }
910 builder->CreateBr(non_short_circuit_bb);
911
912 // Short-circuit case (at least one of the expressions is valid and false).
913 // No need to set validity bit (valid by default).
914 builder->SetInsertPoint(short_circuit_bb);
915 ADD_VISITOR_TRACE("BooleanAndExpression result value false");
916 ADD_VISITOR_TRACE("BooleanAndExpression result validity true");
917 builder->CreateBr(merge_bb);
918
919 // non short-circuit case (All expressions are either true or null).
920 // result valid if all of the exprs are non-null.
921 builder->SetInsertPoint(non_short_circuit_bb);
922 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid);
923 ADD_VISITOR_TRACE("BooleanAndExpression result value true");
924 ADD_VISITOR_TRACE("BooleanAndExpression result validity %T", all_exprs_valid);
925 builder->CreateBr(merge_bb);
926
927 builder->SetInsertPoint(merge_bb);
928 llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value");
929 result_value->addIncoming(types->false_constant(), short_circuit_bb);
930 result_value->addIncoming(types->true_constant(), non_short_circuit_bb);
931 result_.reset(new LValue(result_value));
932 }
933
934 // Boolean OR
935 // if any arg is valid and true,
936 // short-circuit and return TRUE (value=true, valid=true)
937 // else if all args are valid and false
938 // return FALSE (value=false, valid=true)
939 // else
940 // return NULL (value=false, valid=false)
941
Visit(const BooleanOrDex & dex)942 void LLVMGenerator::Visitor::Visit(const BooleanOrDex& dex) {
943 ADD_VISITOR_TRACE("visit BooleanOrExpression");
944 llvm::IRBuilder<>* builder = ir_builder();
945 LLVMTypes* types = generator_->types();
946 llvm::LLVMContext* context = generator_->context();
947
948 // Create blocks for short-circuit.
949 llvm::BasicBlock* short_circuit_bb =
950 llvm::BasicBlock::Create(*context, "short_circuit", function_);
951 llvm::BasicBlock* non_short_circuit_bb =
952 llvm::BasicBlock::Create(*context, "non_short_circuit", function_);
953 llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
954
955 llvm::Value* all_exprs_valid = types->true_constant();
956 for (auto& pair : dex.args()) {
957 LValuePtr current = BuildValueAndValidity(*pair);
958
959 ADD_VISITOR_TRACE("BooleanOrExpression arg value %T", current->data());
960 ADD_VISITOR_TRACE("BooleanOrExpression arg validity %T", current->validity());
961
962 // short-circuit if valid and true.
963 llvm::Value* valid_and_true =
964 builder->CreateAnd(current->data(), current->validity(), "valid_and_true");
965
966 llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
967 builder->CreateCondBr(valid_and_true, short_circuit_bb, else_bb);
968
969 // Emit the else block.
970 builder->SetInsertPoint(else_bb);
971 // remember if any nulls were encountered.
972 all_exprs_valid =
973 builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd");
974 // continue to evaluate the next pair in list.
975 }
976 builder->CreateBr(non_short_circuit_bb);
977
978 // Short-circuit case (at least one of the expressions is valid and true).
979 // No need to set validity bit (valid by default).
980 builder->SetInsertPoint(short_circuit_bb);
981 ADD_VISITOR_TRACE("BooleanOrExpression result value true");
982 ADD_VISITOR_TRACE("BooleanOrExpression result validity true");
983 builder->CreateBr(merge_bb);
984
985 // non short-circuit case (All expressions are either false or null).
986 // result valid if all of the exprs are non-null.
987 builder->SetInsertPoint(non_short_circuit_bb);
988 ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid);
989 ADD_VISITOR_TRACE("BooleanOrExpression result value false");
990 ADD_VISITOR_TRACE("BooleanOrExpression result validity %T", all_exprs_valid);
991 builder->CreateBr(merge_bb);
992
993 builder->SetInsertPoint(merge_bb);
994 llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value");
995 result_value->addIncoming(types->true_constant(), short_circuit_bb);
996 result_value->addIncoming(types->false_constant(), non_short_circuit_bb);
997 result_.reset(new LValue(result_value));
998 }
999
Visit(const InExprDexBase<int32_t> & dex)1000 void LLVMGenerator::Visitor::Visit(const InExprDexBase<int32_t>& dex) {
1001 VisitInExpression<int32_t>(dex);
1002 }
1003
Visit(const InExprDexBase<int64_t> & dex)1004 void LLVMGenerator::Visitor::Visit(const InExprDexBase<int64_t>& dex) {
1005 VisitInExpression<int64_t>(dex);
1006 }
1007
Visit(const InExprDexBase<std::string> & dex)1008 void LLVMGenerator::Visitor::Visit(const InExprDexBase<std::string>& dex) {
1009 VisitInExpression<std::string>(dex);
1010 }
1011
1012 template <typename Type>
VisitInExpression(const InExprDexBase<Type> & dex)1013 void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase<Type>& dex) {
1014 ADD_VISITOR_TRACE("visit In Expression");
1015 LLVMTypes* types = generator_->types();
1016 std::vector<llvm::Value*> params;
1017
1018 const InExprDex<Type>& dex_instance = dynamic_cast<const InExprDex<Type>&>(dex);
1019 /* add the holder at the beginning */
1020 llvm::Constant* ptr_int_cast =
1021 types->i64_constant((int64_t)(dex_instance.in_holder().get()));
1022 params.push_back(ptr_int_cast);
1023
1024 /* eval expr result */
1025 for (auto& pair : dex.args()) {
1026 DexPtr value_expr = pair->value_expr();
1027 value_expr->Accept(*this);
1028 LValue& result_ref = *result();
1029 params.push_back(result_ref.data());
1030
1031 /* length if the result is a string */
1032 if (result_ref.length() != nullptr) {
1033 params.push_back(result_ref.length());
1034 }
1035
1036 /* push the validity of eval expr result */
1037 llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
1038 params.push_back(validity_expr);
1039 }
1040
1041 llvm::Type* ret_type = types->IRType(arrow::Type::type::BOOL);
1042
1043 llvm::Value* value =
1044 generator_->AddFunctionCall(dex.runtime_function(), ret_type, params);
1045 result_.reset(new LValue(value));
1046 }
1047
BuildIfElse(llvm::Value * condition,std::function<LValuePtr ()> then_func,std::function<LValuePtr ()> else_func,DataTypePtr result_type)1048 LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition,
1049 std::function<LValuePtr()> then_func,
1050 std::function<LValuePtr()> else_func,
1051 DataTypePtr result_type) {
1052 llvm::IRBuilder<>* builder = ir_builder();
1053 llvm::LLVMContext* context = generator_->context();
1054 LLVMTypes* types = generator_->types();
1055
1056 // Create blocks for the then, else and merge cases.
1057 llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(*context, "then", function_);
1058 llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
1059 llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
1060
1061 builder->CreateCondBr(condition, then_bb, else_bb);
1062
1063 // Emit the then block.
1064 builder->SetInsertPoint(then_bb);
1065 LValuePtr then_lvalue = then_func();
1066 builder->CreateBr(merge_bb);
1067
1068 // refresh then_bb for phi (could have changed due to code generation of then_vv).
1069 then_bb = builder->GetInsertBlock();
1070
1071 // Emit the else block.
1072 builder->SetInsertPoint(else_bb);
1073 LValuePtr else_lvalue = else_func();
1074 builder->CreateBr(merge_bb);
1075
1076 // refresh else_bb for phi (could have changed due to code generation of else_vv).
1077 else_bb = builder->GetInsertBlock();
1078
1079 // Emit the merge block.
1080 builder->SetInsertPoint(merge_bb);
1081 auto llvm_type = types->IRType(result_type->id());
1082 llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value");
1083 result_value->addIncoming(then_lvalue->data(), then_bb);
1084 result_value->addIncoming(else_lvalue->data(), else_bb);
1085
1086 LValuePtr ret;
1087 switch (result_type->id()) {
1088 case arrow::Type::STRING: {
1089 llvm::PHINode* result_length;
1090 result_length = builder->CreatePHI(types->i32_type(), 2, "res_length");
1091 result_length->addIncoming(then_lvalue->length(), then_bb);
1092 result_length->addIncoming(else_lvalue->length(), else_bb);
1093 ret = std::make_shared<LValue>(result_value, result_length);
1094 break;
1095 }
1096
1097 case arrow::Type::DECIMAL:
1098 ret = generator_->BuildDecimalLValue(result_value, result_type);
1099 break;
1100
1101 default:
1102 ret = std::make_shared<LValue>(result_value);
1103 break;
1104 }
1105 return ret;
1106 }
1107
BuildValueAndValidity(const ValueValidityPair & pair)1108 LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair& pair) {
1109 // generate code for value
1110 auto value_expr = pair.value_expr();
1111 value_expr->Accept(*this);
1112 auto value = result()->data();
1113 auto length = result()->length();
1114
1115 // generate code for validity
1116 auto validity = BuildCombinedValidity(pair.validity_exprs());
1117
1118 return std::make_shared<LValue>(value, length, validity);
1119 }
1120
BuildFunctionCall(const NativeFunction * func,DataTypePtr arrow_return_type,std::vector<llvm::Value * > * params)1121 LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
1122 DataTypePtr arrow_return_type,
1123 std::vector<llvm::Value*>* params) {
1124 auto types = generator_->types();
1125 auto arrow_return_type_id = arrow_return_type->id();
1126 auto llvm_return_type = types->IRType(arrow_return_type_id);
1127 DecimalIR decimalIR(generator_->engine_.get());
1128
1129 if (arrow_return_type_id == arrow::Type::DECIMAL) {
1130 // For decimal fns, the output precision/scale are passed along as parameters.
1131 //
1132 // convert from this :
1133 // out = add_decimal(v1, p1, s1, v2, p2, s2)
1134 // to:
1135 // out = add_decimal(v1, p1, s1, v2, p2, s2, out_p, out_s)
1136
1137 // Append the out_precision and out_scale
1138 auto ret_lvalue = generator_->BuildDecimalLValue(nullptr, arrow_return_type);
1139 params->push_back(ret_lvalue->precision());
1140 params->push_back(ret_lvalue->scale());
1141
1142 // Make the function call
1143 auto out = decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params);
1144 ret_lvalue->set_data(out);
1145 return std::move(ret_lvalue);
1146 } else {
1147 bool isDecimalFunction = false;
1148 for (auto& arg : *params) {
1149 if (arg->getType() == types->i128_type()) {
1150 isDecimalFunction = true;
1151 }
1152 }
1153 // add extra arg for return length for variable len return types (allocated on stack).
1154 llvm::AllocaInst* result_len_ptr = nullptr;
1155 if (arrow::is_binary_like(arrow_return_type_id)) {
1156 result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0,
1157 "result_len", entry_block_);
1158 params->push_back(result_len_ptr);
1159 has_arena_allocs_ = true;
1160 }
1161
1162 // Make the function call
1163 llvm::IRBuilder<>* builder = ir_builder();
1164 auto value =
1165 isDecimalFunction
1166 ? decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params)
1167 : generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params);
1168 auto value_len =
1169 (result_len_ptr == nullptr) ? nullptr : builder->CreateLoad(result_len_ptr);
1170 return std::make_shared<LValue>(value, value_len);
1171 }
1172 }
1173
BuildParams(FunctionHolder * holder,const ValueValidityPairVector & args,bool with_validity,bool with_context)1174 std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams(
1175 FunctionHolder* holder, const ValueValidityPairVector& args, bool with_validity,
1176 bool with_context) {
1177 LLVMTypes* types = generator_->types();
1178 std::vector<llvm::Value*> params;
1179
1180 // add context if required.
1181 if (with_context) {
1182 params.push_back(arg_context_ptr_);
1183 }
1184
1185 // if the function has holder, add the holder pointer.
1186 if (holder != nullptr) {
1187 auto ptr = types->i64_constant((int64_t)holder);
1188 params.push_back(ptr);
1189 }
1190
1191 // build the function params, along with the validities.
1192 for (auto& pair : args) {
1193 // build value.
1194 DexPtr value_expr = pair->value_expr();
1195 value_expr->Accept(*this);
1196 LValue& result_ref = *result();
1197
1198 // append all the parameters corresponding to this LValue.
1199 result_ref.AppendFunctionParams(¶ms);
1200
1201 // build validity.
1202 if (with_validity) {
1203 llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
1204 params.push_back(validity_expr);
1205 }
1206 }
1207
1208 return params;
1209 }
1210
1211 // Bitwise-AND of a vector of bits to get the combined validity.
BuildCombinedValidity(const DexVector & validities)1212 llvm::Value* LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector& validities) {
1213 llvm::IRBuilder<>* builder = ir_builder();
1214 LLVMTypes* types = generator_->types();
1215
1216 llvm::Value* isValid = types->true_constant();
1217 for (auto& dex : validities) {
1218 dex->Accept(*this);
1219 isValid = builder->CreateAnd(isValid, result()->data(), "validityBitAnd");
1220 }
1221 ADD_VISITOR_TRACE("combined validity is %T", isValid);
1222 return isValid;
1223 }
1224
GetBufferReference(int idx,BufferType buffer_type,FieldPtr field)1225 llvm::Value* LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buffer_type,
1226 FieldPtr field) {
1227 llvm::IRBuilder<>* builder = ir_builder();
1228
1229 // Switch to the entry block to create a reference.
1230 llvm::BasicBlock* saved_block = builder->GetInsertBlock();
1231 builder->SetInsertPoint(entry_block_);
1232
1233 llvm::Value* slot_ref = nullptr;
1234 switch (buffer_type) {
1235 case kBufferTypeValidity:
1236 slot_ref = generator_->GetValidityReference(arg_addrs_, idx, field);
1237 break;
1238
1239 case kBufferTypeData:
1240 slot_ref = generator_->GetDataReference(arg_addrs_, idx, field);
1241 break;
1242
1243 case kBufferTypeOffsets:
1244 slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field);
1245 break;
1246 }
1247
1248 // Revert to the saved block.
1249 builder->SetInsertPoint(saved_block);
1250 return slot_ref;
1251 }
1252
GetSliceOffset(int idx)1253 llvm::Value* LLVMGenerator::Visitor::GetSliceOffset(int idx) {
1254 return slice_offsets_[idx];
1255 }
1256
GetLocalBitMapReference(int idx)1257 llvm::Value* LLVMGenerator::Visitor::GetLocalBitMapReference(int idx) {
1258 llvm::IRBuilder<>* builder = ir_builder();
1259
1260 // Switch to the entry block to create a reference.
1261 llvm::BasicBlock* saved_block = builder->GetInsertBlock();
1262 builder->SetInsertPoint(entry_block_);
1263
1264 llvm::Value* slot_ref = generator_->GetLocalBitMapReference(arg_local_bitmaps_, idx);
1265
1266 // Revert to the saved block.
1267 builder->SetInsertPoint(saved_block);
1268 return slot_ref;
1269 }
1270
1271 /// The local bitmap is pre-filled with 1s. Clear only if invalid.
ClearLocalBitMapIfNotValid(int local_bitmap_idx,llvm::Value * is_valid)1272 void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx,
1273 llvm::Value* is_valid) {
1274 llvm::Value* slot_ref = GetLocalBitMapReference(local_bitmap_idx);
1275 generator_->ClearPackedBitValueIfFalse(slot_ref, loop_var_, is_valid);
1276 }
1277
1278 // Hooks for tracing/printfs.
1279 //
1280 // replace %T with the type-specific format specifier.
1281 // For some reason, float/double literals are getting lost when printing with the generic
1282 // printf. so, use a wrapper instead.
ReplaceFormatInTrace(const std::string & in_msg,llvm::Value * value,std::string * print_fn)1283 std::string LLVMGenerator::ReplaceFormatInTrace(const std::string& in_msg,
1284 llvm::Value* value,
1285 std::string* print_fn) {
1286 std::string msg = in_msg;
1287 std::size_t pos = msg.find("%T");
1288 if (pos == std::string::npos) {
1289 DCHECK(0);
1290 return msg;
1291 }
1292
1293 llvm::Type* type = value->getType();
1294 const char* fmt = "";
1295 if (type->isIntegerTy(1) || type->isIntegerTy(8) || type->isIntegerTy(16) ||
1296 type->isIntegerTy(32)) {
1297 fmt = "%d";
1298 } else if (type->isIntegerTy(64)) {
1299 // bigint
1300 fmt = "%lld";
1301 } else if (type->isFloatTy()) {
1302 // float
1303 fmt = "%f";
1304 *print_fn = "print_float";
1305 } else if (type->isDoubleTy()) {
1306 // float
1307 fmt = "%lf";
1308 *print_fn = "print_double";
1309 } else if (type->isPointerTy()) {
1310 // string
1311 fmt = "%s";
1312 } else {
1313 DCHECK(0);
1314 }
1315 msg.replace(pos, 2, fmt);
1316 return msg;
1317 }
1318
AddTrace(const std::string & msg,llvm::Value * value)1319 void LLVMGenerator::AddTrace(const std::string& msg, llvm::Value* value) {
1320 if (!enable_ir_traces_) {
1321 return;
1322 }
1323
1324 std::string dmsg = "IR_TRACE:: " + msg + "\n";
1325 std::string print_fn_name = "printf";
1326 if (value != nullptr) {
1327 dmsg = ReplaceFormatInTrace(dmsg, value, &print_fn_name);
1328 }
1329 trace_strings_.push_back(dmsg);
1330
1331 // cast this to an llvm pointer.
1332 const char* str = trace_strings_.back().c_str();
1333 llvm::Constant* str_int_cast = types()->i64_constant((int64_t)str);
1334 llvm::Constant* str_ptr_cast =
1335 llvm::ConstantExpr::getIntToPtr(str_int_cast, types()->i8_ptr_type());
1336
1337 std::vector<llvm::Value*> args;
1338 args.push_back(str_ptr_cast);
1339 if (value != nullptr) {
1340 args.push_back(value);
1341 }
1342 AddFunctionCall(print_fn_name, types()->i32_type(), args);
1343 }
1344
1345 } // namespace gandiva
1346