1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "gandiva/engine.h"
19
20 #include <gtest/gtest.h>
21 #include <functional>
22 #include "gandiva/llvm_types.h"
23 #include "gandiva/tests/test_util.h"
24
25 namespace gandiva {
26
// Signature of the JIT-compiled "add_longs" function built by
// TestEngine::BuildVecAdd. The count is int32_t to match the i32 argument
// declared in the generated IR prototype exactly.
using add_vector_func_t = int64_t (*)(int64_t* data, int32_t n);
28
29 class TestEngine : public ::testing::Test {
30 protected:
BuildVecAdd(Engine * engine)31 llvm::Function* BuildVecAdd(Engine* engine) {
32 auto types = engine->types();
33 llvm::IRBuilder<>* builder = engine->ir_builder();
34 llvm::LLVMContext* context = engine->context();
35
36 // Create fn prototype :
37 // int64_t add_longs(int64_t *elements, int32_t nelements)
38 std::vector<llvm::Type*> arguments;
39 arguments.push_back(types->i64_ptr_type());
40 arguments.push_back(types->i32_type());
41 llvm::FunctionType* prototype =
42 llvm::FunctionType::get(types->i64_type(), arguments, false /*isVarArg*/);
43
44 // Create fn
45 std::string func_name = "add_longs";
46 engine->AddFunctionToCompile(func_name);
47 llvm::Function* fn = llvm::Function::Create(
48 prototype, llvm::GlobalValue::ExternalLinkage, func_name, engine->module());
49 assert(fn != nullptr);
50
51 // Name the arguments
52 llvm::Function::arg_iterator args = fn->arg_begin();
53 llvm::Value* arg_elements = &*args;
54 arg_elements->setName("elements");
55 ++args;
56 llvm::Value* arg_nelements = &*args;
57 arg_nelements->setName("nelements");
58 ++args;
59
60 llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context, "entry", fn);
61 llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context, "loop", fn);
62 llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context, "exit", fn);
63
64 // Loop entry
65 builder->SetInsertPoint(loop_entry);
66 builder->CreateBr(loop_body);
67
68 // Loop body
69 builder->SetInsertPoint(loop_body);
70
71 llvm::PHINode* loop_var = builder->CreatePHI(types->i32_type(), 2, "loop_var");
72 llvm::PHINode* sum = builder->CreatePHI(types->i64_type(), 2, "sum");
73
74 loop_var->addIncoming(types->i32_constant(0), loop_entry);
75 sum->addIncoming(types->i64_constant(0), loop_entry);
76
77 // setup loop PHI
78 llvm::Value* loop_update =
79 builder->CreateAdd(loop_var, types->i32_constant(1), "loop_var+1");
80 loop_var->addIncoming(loop_update, loop_body);
81
82 // get the current value
83 llvm::Value* offset = builder->CreateGEP(arg_elements, loop_var, "offset");
84 llvm::Value* current_value = builder->CreateLoad(offset, "value");
85
86 // setup sum PHI
87 llvm::Value* sum_update = builder->CreateAdd(sum, current_value, "sum+ith");
88 sum->addIncoming(sum_update, loop_body);
89
90 // check loop_var
91 llvm::Value* loop_var_check =
92 builder->CreateICmpSLT(loop_update, arg_nelements, "loop_var < nrec");
93 builder->CreateCondBr(loop_var_check, loop_body, loop_exit);
94
95 // Loop exit
96 builder->SetInsertPoint(loop_exit);
97 builder->CreateRet(sum_update);
98 return fn;
99 }
100
BuildEngine()101 void BuildEngine() { ASSERT_OK(Engine::Make(TestConfiguration(), &engine)); }
102
103 std::unique_ptr<Engine> engine;
104 std::shared_ptr<Configuration> configuration = TestConfiguration();
105 };
106
// JIT-compiles add_longs with optimization disabled and checks it sums a
// small array correctly.
TEST_F(TestEngine, TestAddUnoptimised) {
  configuration->set_optimize(false);
  // ASSERT_OK inside BuildEngine() only returns from the helper; propagate a
  // fatal failure here so a failed Engine::Make does not leave `engine` null
  // and crash on the next line.
  ASSERT_NO_FATAL_FAILURE(BuildEngine());

  llvm::Function* ir_func = BuildVecAdd(engine.get());
  ASSERT_OK(engine->FinalizeModule());
  auto compiled = engine->CompiledFunction(ir_func);
  // Report a missing compiled function as a test failure instead of calling
  // through a null pointer.
  ASSERT_NE(compiled, nullptr);
  auto add_func = reinterpret_cast<add_vector_func_t>(compiled);

  int64_t my_array[] = {1, 3, -5, 8, 10};
  EXPECT_EQ(add_func(my_array, 5), 17);
}
118
// JIT-compiles add_longs with optimization enabled and checks it sums a
// small array correctly.
TEST_F(TestEngine, TestAddOptimised) {
  configuration->set_optimize(true);
  // ASSERT_OK inside BuildEngine() only returns from the helper; propagate a
  // fatal failure here so a failed Engine::Make does not leave `engine` null
  // and crash on the next line.
  ASSERT_NO_FATAL_FAILURE(BuildEngine());

  llvm::Function* ir_func = BuildVecAdd(engine.get());
  ASSERT_OK(engine->FinalizeModule());
  auto compiled = engine->CompiledFunction(ir_func);
  // Report a missing compiled function as a test failure instead of calling
  // through a null pointer.
  ASSERT_NE(compiled, nullptr);
  auto add_func = reinterpret_cast<add_vector_func_t>(compiled);

  int64_t my_array[] = {1, 3, -5, 8, 10};
  EXPECT_EQ(add_func(my_array, 5), 17);
}
130
131 } // namespace gandiva
132