1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <gtest/gtest.h>
19 #include "arrow/memory_pool.h"
20 #include "arrow/status.h"
21
22 #include "gandiva/projector.h"
23 #include "gandiva/tests/test_util.h"
24 #include "gandiva/tree_expr_builder.h"
25
26 namespace gandiva {
27
28 using arrow::boolean;
29 using arrow::float32;
30 using arrow::int32;
31
32 class TestIfExpr : public ::testing::Test {
33 public:
SetUp()34 void SetUp() { pool_ = arrow::default_memory_pool(); }
35
36 protected:
37 arrow::MemoryPool* pool_;
38 };
39
// Simple select: if (a > b) a else b.
// Row 3 has a null `a`; the expected output shows that the null condition
// falls through to the else branch and yields the (valid) value of `b`.
TEST_F(TestIfExpr, TestSimple) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto schema = arrow::schema({fielda, fieldb});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a > b)
  //   a
  // else
  //   b
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto condition =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
  auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32());

  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal on failure: `projector` is dereferenced below.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data
  const int num_records = 4;
  auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false});
  auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});

  // expected output
  auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal on failure so outputs.at(0) is only read after a successful evaluation.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
86
// Branches that are themselves expressions: if (a > b) a + b else a - b.
// Row 3 has a null `a`, so both branches involve a null operand and the
// expected result for that row is null.
TEST_F(TestIfExpr, TestSimpleArithmetic) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto schema = arrow::schema({fielda, fieldb});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a > b)
  //   a + b
  // else
  //   a - b
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto condition =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
  auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
  auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32());
  auto if_node = TreeExprBuilder::MakeIf(condition, sum, sub, int32());

  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal on failure: `projector` is dereferenced below.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data
  const int num_records = 4;
  auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false});
  auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});

  // expected output
  auto exp = MakeArrowArrayInt32({15, -3, -35, 0}, {true, true, true, false});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal on failure so outputs.at(0) is only read after a successful evaluation.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
135
// Else-if chain: an if node nested in the else branch of another if node.
// The sample data covers a > b, a < b, a == b (the multiply branch) and a
// null `a` (expected null result).
TEST_F(TestIfExpr, TestNested) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto schema = arrow::schema({fielda, fieldb});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a > b)
  //   a + b
  // else if (a < b)
  //   a - b
  // else
  //   a * b
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto condition_gt =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
  auto condition_lt =
      TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
  auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
  auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32());
  auto mult = TreeExprBuilder::MakeFunction("multiply", {node_a, node_b}, int32());
  auto else_node = TreeExprBuilder::MakeIf(condition_lt, sub, mult, int32());
  auto if_node = TreeExprBuilder::MakeIf(condition_gt, sum, else_node, int32());

  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal on failure: `projector` is dereferenced below.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data
  const int num_records = 4;
  auto array0 = MakeArrowArrayInt32({10, 12, 15, 5}, {true, true, true, false});
  auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});

  // expected output
  auto exp = MakeArrowArrayInt32({15, -3, 225, 0}, {true, true, true, false});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal on failure so outputs.at(0) is only read after a successful evaluation.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
190
// An if node nested in the *then* branch of another if node, with literal
// thresholds. Rows 3-5 repeat the `a` values of rows 0-2 but with null `b` and
// `c`, so whichever branch is taken involves a null operand (expected null).
TEST_F(TestIfExpr, TestNestedInIf) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto fieldc = field("c", int32());
  auto schema = arrow::schema({fielda, fieldb, fieldc});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a > 10)
  //   if (a < 20)
  //     a + b
  //   else
  //     b + c
  // else
  //   a + c
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto node_c = TreeExprBuilder::MakeField(fieldc);

  auto literal_10 = TreeExprBuilder::MakeLiteral(10);
  auto literal_20 = TreeExprBuilder::MakeLiteral(20);

  auto gt_10 =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_10}, boolean());
  auto lt_20 =
      TreeExprBuilder::MakeFunction("less_than", {node_a, literal_20}, boolean());
  auto sum_ab = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
  auto sum_bc = TreeExprBuilder::MakeFunction("add", {node_b, node_c}, int32());
  auto sum_ac = TreeExprBuilder::MakeFunction("add", {node_a, node_c}, int32());

  auto if_lt_20 = TreeExprBuilder::MakeIf(lt_20, sum_ab, sum_bc, int32());
  auto if_gt_10 = TreeExprBuilder::MakeIf(gt_10, if_lt_20, sum_ac, int32());

  auto expr = TreeExprBuilder::MakeExpression(if_gt_10, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal on failure: `projector` is dereferenced below.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data
  const int num_records = 6;
  auto array_a =
      MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true});
  auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
                                     {true, true, true, false, false, false});
  auto array_c = MakeArrowArrayInt32({35, 45, 55, 35, 45, 55},
                                     {true, true, true, false, false, false});

  // expected output
  auto exp =
      MakeArrowArrayInt32({55, 33, 60, 0, 0, 0}, {true, true, true, false, false, false});

  // prepare input record batch
  auto in_batch =
      arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal on failure so outputs.at(0) is only read after a successful evaluation.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
259
// If nodes used as the *condition* of another if node. The inner condition is
// a three-valued boolean (true / false / null); rows 3-5 have null `b`, which
// makes the condition null, and the expected output for those rows is 3. That
// shows that neither `cond` nor `not(cond)` selects its then-branch when
// `cond` is null.
TEST_F(TestIfExpr, TestNestedInCondition) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto schema = arrow::schema({fielda, fieldb});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (if (a > b) then true else if (a < b) false else null)
  //   1
  // else if !(if (a > b) then true else if (a < b) false else null)
  //   2
  // else
  //   3
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto literal_1 = TreeExprBuilder::MakeLiteral(1);
  auto literal_2 = TreeExprBuilder::MakeLiteral(2);
  auto literal_3 = TreeExprBuilder::MakeLiteral(3);
  auto literal_true = TreeExprBuilder::MakeLiteral(true);
  auto literal_false = TreeExprBuilder::MakeLiteral(false);
  auto literal_null = TreeExprBuilder::MakeNull(boolean());

  auto a_gt_b =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
  auto a_lt_b = TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
  auto cond_else =
      TreeExprBuilder::MakeIf(a_lt_b, literal_false, literal_null, boolean());
  auto cond_if = TreeExprBuilder::MakeIf(a_gt_b, literal_true, cond_else, boolean());
  auto not_cond_if = TreeExprBuilder::MakeFunction("not", {cond_if}, boolean());

  auto outer_else = TreeExprBuilder::MakeIf(not_cond_if, literal_2, literal_3, int32());
  auto outer_if = TreeExprBuilder::MakeIf(cond_if, literal_1, outer_else, int32());
  auto expr = TreeExprBuilder::MakeExpression(outer_if, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal on failure: `projector` is dereferenced below.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data
  const int num_records = 6;
  auto array_a =
      MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true});
  auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
                                     {true, true, true, false, false, false});
  // expected output
  auto exp =
      MakeArrowArrayInt32({1, 2, 2, 3, 3, 3}, {true, true, true, true, true, true});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal on failure so outputs.at(0) is only read after a successful evaluation.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
323
// A deep else-if chain (19 nested if nodes) generated in a loop: each value of
// `a` maps to the smallest multiple of 10 (from 10 to 190) that is strictly
// greater than `a`, or to 200 if no threshold matches.
TEST_F(TestIfExpr, TestBigNested) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto schema = arrow::schema({fielda});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a < 10)
  //   10
  // else if (a < 20)
  //   20
  // ..
  // ..
  // else if (a < 190)
  //   190
  // else
  //   200
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto top_node = TreeExprBuilder::MakeLiteral(200);
  // The chain is built innermost-first: each iteration wraps the previous tree
  // in its else branch, so the last iteration (thresh == 10) is the outermost.
  for (int thresh = 190; thresh > 0; thresh -= 10) {
    auto literal = TreeExprBuilder::MakeLiteral(thresh);
    auto condition =
        TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean());
    auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32());
    top_node = if_node;
  }
  auto expr = TreeExprBuilder::MakeExpression(top_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal on failure: `projector` is dereferenced below.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data
  const int num_records = 4;
  auto array0 = MakeArrowArrayInt32({10, 102, 158, 302}, {true, true, true, true});

  // expected output
  auto exp = MakeArrowArrayInt32({20, 110, 160, 200}, {true, true, true, true});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal on failure so outputs.at(0) is only read after a successful evaluation.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
377
378 } // namespace gandiva
379