1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "gandiva/engine.h"
19 
20 #include <gtest/gtest.h>
21 #include <functional>
22 #include "gandiva/llvm_types.h"
23 #include "gandiva/tests/test_util.h"
24 
25 namespace gandiva {
26 
// Native signature of the JIT-compiled "add_longs" function built by the
// TestEngine fixture: sums the first `n` entries of `data`.
using add_vector_func_t = int64_t (*)(int64_t* data, int n);
28 
29 class TestEngine : public ::testing::Test {
30  protected:
BuildVecAdd(Engine * engine)31   llvm::Function* BuildVecAdd(Engine* engine) {
32     auto types = engine->types();
33     llvm::IRBuilder<>* builder = engine->ir_builder();
34     llvm::LLVMContext* context = engine->context();
35 
36     // Create fn prototype :
37     //   int64_t add_longs(int64_t *elements, int32_t nelements)
38     std::vector<llvm::Type*> arguments;
39     arguments.push_back(types->i64_ptr_type());
40     arguments.push_back(types->i32_type());
41     llvm::FunctionType* prototype =
42         llvm::FunctionType::get(types->i64_type(), arguments, false /*isVarArg*/);
43 
44     // Create fn
45     std::string func_name = "add_longs";
46     engine->AddFunctionToCompile(func_name);
47     llvm::Function* fn = llvm::Function::Create(
48         prototype, llvm::GlobalValue::ExternalLinkage, func_name, engine->module());
49     assert(fn != nullptr);
50 
51     // Name the arguments
52     llvm::Function::arg_iterator args = fn->arg_begin();
53     llvm::Value* arg_elements = &*args;
54     arg_elements->setName("elements");
55     ++args;
56     llvm::Value* arg_nelements = &*args;
57     arg_nelements->setName("nelements");
58     ++args;
59 
60     llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context, "entry", fn);
61     llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context, "loop", fn);
62     llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context, "exit", fn);
63 
64     // Loop entry
65     builder->SetInsertPoint(loop_entry);
66     builder->CreateBr(loop_body);
67 
68     // Loop body
69     builder->SetInsertPoint(loop_body);
70 
71     llvm::PHINode* loop_var = builder->CreatePHI(types->i32_type(), 2, "loop_var");
72     llvm::PHINode* sum = builder->CreatePHI(types->i64_type(), 2, "sum");
73 
74     loop_var->addIncoming(types->i32_constant(0), loop_entry);
75     sum->addIncoming(types->i64_constant(0), loop_entry);
76 
77     // setup loop PHI
78     llvm::Value* loop_update =
79         builder->CreateAdd(loop_var, types->i32_constant(1), "loop_var+1");
80     loop_var->addIncoming(loop_update, loop_body);
81 
82     // get the current value
83     llvm::Value* offset = builder->CreateGEP(arg_elements, loop_var, "offset");
84     llvm::Value* current_value = builder->CreateLoad(offset, "value");
85 
86     // setup sum PHI
87     llvm::Value* sum_update = builder->CreateAdd(sum, current_value, "sum+ith");
88     sum->addIncoming(sum_update, loop_body);
89 
90     // check loop_var
91     llvm::Value* loop_var_check =
92         builder->CreateICmpSLT(loop_update, arg_nelements, "loop_var < nrec");
93     builder->CreateCondBr(loop_var_check, loop_body, loop_exit);
94 
95     // Loop exit
96     builder->SetInsertPoint(loop_exit);
97     builder->CreateRet(sum_update);
98     return fn;
99   }
100 
BuildEngine()101   void BuildEngine() { ASSERT_OK(Engine::Make(TestConfiguration(), &engine)); }
102 
103   std::unique_ptr<Engine> engine;
104   std::shared_ptr<Configuration> configuration = TestConfiguration();
105 };
106 
// Builds add_longs with optimisation switched off in the fixture
// configuration, JIT-compiles it, and checks the sum of a small fixed array.
TEST_F(TestEngine, TestAddUnoptimised) {
  // NOTE(review): this flag reaches the engine only if BuildEngine() passes the
  // fixture's `configuration` to Engine::Make.
  configuration->set_optimize(false);
  BuildEngine();

  llvm::Function* add_longs_ir = BuildVecAdd(engine.get());
  ASSERT_OK(engine->FinalizeModule());
  const auto add_longs =
      reinterpret_cast<add_vector_func_t>(engine->CompiledFunction(add_longs_ir));

  int64_t values[] = {1, 3, -5, 8, 10};  // 1 + 3 - 5 + 8 + 10 == 17
  const int64_t total = add_longs(values, 5);
  EXPECT_EQ(total, 17);
}
118 
// Builds add_longs with optimisation switched on in the fixture configuration,
// JIT-compiles it, and checks the sum of a small fixed array.
TEST_F(TestEngine, TestAddOptimised) {
  // NOTE(review): this flag reaches the engine only if BuildEngine() passes the
  // fixture's `configuration` to Engine::Make.
  configuration->set_optimize(true);
  BuildEngine();

  llvm::Function* add_longs_ir = BuildVecAdd(engine.get());
  ASSERT_OK(engine->FinalizeModule());
  const auto add_longs =
      reinterpret_cast<add_vector_func_t>(engine->CompiledFunction(add_longs_ir));

  int64_t values[] = {1, 3, -5, 8, 10};  // 1 + 3 - 5 + 8 + 10 == 17
  const int64_t total = add_longs(values, 5);
  EXPECT_EQ(total, 17);
}
130 
131 }  // namespace gandiva
132