1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include <gtest/gtest.h>
19 #include "arrow/memory_pool.h"
20 #include "arrow/status.h"
21 
22 #include "gandiva/projector.h"
23 #include "gandiva/tests/test_util.h"
24 #include "gandiva/tree_expr_builder.h"
25 
26 namespace gandiva {
27 
28 using arrow::boolean;
29 using arrow::float32;
30 using arrow::int32;
31 
32 class TestIfExpr : public ::testing::Test {
33  public:
SetUp()34   void SetUp() { pool_ = arrow::default_memory_pool(); }
35 
36  protected:
37   arrow::MemoryPool* pool_;
38 };
39 
// Projects `if (a > b) a else b` over two int32 columns and compares the
// projector output against a hand-computed array.
TEST_F(TestIfExpr, TestSimple) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto schema = arrow::schema({fielda, fieldb});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a > b)
  //   a
  // else
  //   b
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto condition =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
  auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32());

  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Must be fatal: `projector` is dereferenced below and may still be null if
  // the build failed.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data.
  // NOTE(review): the second list passed to MakeArrowArrayInt32 is presumably the
  // per-row validity mask (last row of `a` is flagged false) - confirm in test_util.h.
  int num_records = 4;
  auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false});
  auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});

  // expected output
  auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Must be fatal: outputs.at(0) below would throw if `outputs` were left empty.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
86 
// Projects `if (a > b) a + b else a - b` over two int32 columns, so both
// branches of the if are function nodes rather than plain field references.
TEST_F(TestIfExpr, TestSimpleArithmetic) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto schema = arrow::schema({fielda, fieldb});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a > b)
  //   a + b
  // else
  //   a - b
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto condition =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
  auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
  auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32());
  auto if_node = TreeExprBuilder::MakeIf(condition, sum, sub, int32());

  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Must be fatal: `projector` is dereferenced below and may still be null if
  // the build failed.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data.
  // NOTE(review): the second list passed to MakeArrowArrayInt32 is presumably the
  // per-row validity mask (last row of `a` is flagged false) - confirm in test_util.h.
  int num_records = 4;
  auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false});
  auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});

  // expected output: 10 + 5, 12 - 15, -20 - 15, and a last slot flagged false.
  auto exp = MakeArrowArrayInt32({15, -3, -35, 0}, {true, true, true, false});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Must be fatal: outputs.at(0) below would throw if `outputs` were left empty.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
135 
// Projects an if / else-if / else chain where the nested if sits in the else
// branch: `if (a > b) a + b else if (a < b) a - b else a * b`.
TEST_F(TestIfExpr, TestNested) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto schema = arrow::schema({fielda, fieldb});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a > b)
  //   a + b
  // else if (a < b)
  //   a - b
  // else
  //   a * b
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto condition_gt =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
  auto condition_lt =
      TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
  auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
  auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32());
  auto mult = TreeExprBuilder::MakeFunction("multiply", {node_a, node_b}, int32());
  // The inner if becomes the else branch of the outer if.
  auto else_node = TreeExprBuilder::MakeIf(condition_lt, sub, mult, int32());
  auto if_node = TreeExprBuilder::MakeIf(condition_gt, sum, else_node, int32());

  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Must be fatal: `projector` is dereferenced below and may still be null if
  // the build failed.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data.
  // NOTE(review): the second list passed to MakeArrowArrayInt32 is presumably the
  // per-row validity mask (last row of `a` is flagged false) - confirm in test_util.h.
  int num_records = 4;
  auto array0 = MakeArrowArrayInt32({10, 12, 15, 5}, {true, true, true, false});
  auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});

  // expected output: 10 + 5, 12 - 15, 15 * 15, and a last slot flagged false.
  auto exp = MakeArrowArrayInt32({15, -3, 225, 0}, {true, true, true, false});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Must be fatal: outputs.at(0) below would throw if `outputs` were left empty.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
190 
TEST_F(TestIfExpr,TestNestedInIf)191 TEST_F(TestIfExpr, TestNestedInIf) {
192   // schema for input fields
193   auto fielda = field("a", int32());
194   auto fieldb = field("b", int32());
195   auto fieldc = field("c", int32());
196   auto schema = arrow::schema({fielda, fieldb, fieldc});
197 
198   // output fields
199   auto field_result = field("res", int32());
200 
201   // build expression.
202   // if (a > 10)
203   //   if (a < 20)
204   //     a + b
205   //   else
206   //     b + c
207   // else
208   //   a + c
209   auto node_a = TreeExprBuilder::MakeField(fielda);
210   auto node_b = TreeExprBuilder::MakeField(fieldb);
211   auto node_c = TreeExprBuilder::MakeField(fieldc);
212 
213   auto literal_10 = TreeExprBuilder::MakeLiteral(10);
214   auto literal_20 = TreeExprBuilder::MakeLiteral(20);
215 
216   auto gt_10 =
217       TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_10}, boolean());
218   auto lt_20 =
219       TreeExprBuilder::MakeFunction("less_than", {node_a, literal_20}, boolean());
220   auto sum_ab = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
221   auto sum_bc = TreeExprBuilder::MakeFunction("add", {node_b, node_c}, int32());
222   auto sum_ac = TreeExprBuilder::MakeFunction("add", {node_a, node_c}, int32());
223 
224   auto if_lt_20 = TreeExprBuilder::MakeIf(lt_20, sum_ab, sum_bc, int32());
225   auto if_gt_10 = TreeExprBuilder::MakeIf(gt_10, if_lt_20, sum_ac, int32());
226 
227   auto expr = TreeExprBuilder::MakeExpression(if_gt_10, field_result);
228 
229   // Build a projector for the expressions.
230   std::shared_ptr<Projector> projector;
231   auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
232   EXPECT_TRUE(status.ok());
233 
234   // Create a row-batch with some sample data
235   int num_records = 6;
236   auto array_a =
237       MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true});
238   auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
239                                      {true, true, true, false, false, false});
240   auto array_c = MakeArrowArrayInt32({35, 45, 55, 35, 45, 55},
241                                      {true, true, true, false, false, false});
242 
243   // expected output
244   auto exp =
245       MakeArrowArrayInt32({55, 33, 60, 0, 0, 0}, {true, true, true, false, false, false});
246 
247   // prepare input record batch
248   auto in_batch =
249       arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c});
250 
251   // Evaluate expression
252   arrow::ArrayVector outputs;
253   status = projector->Evaluate(*in_batch, pool_, &outputs);
254   EXPECT_TRUE(status.ok());
255 
256   // Validate results
257   EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
258 }
259 
TEST_F(TestIfExpr,TestNestedInCondition)260 TEST_F(TestIfExpr, TestNestedInCondition) {
261   // schema for input fields
262   auto fielda = field("a", int32());
263   auto fieldb = field("b", int32());
264   auto schema = arrow::schema({fielda, fieldb});
265 
266   // output fields
267   auto field_result = field("res", int32());
268 
269   // build expression.
270   // if (if (a > b) then true else if (a < b) false else null)
271   //   1
272   // else if !(if (a > b) then true else if (a < b) false else null)
273   //   2
274   // else
275   //   3
276   auto node_a = TreeExprBuilder::MakeField(fielda);
277   auto node_b = TreeExprBuilder::MakeField(fieldb);
278   auto literal_1 = TreeExprBuilder::MakeLiteral(1);
279   auto literal_2 = TreeExprBuilder::MakeLiteral(2);
280   auto literal_3 = TreeExprBuilder::MakeLiteral(3);
281   auto literal_true = TreeExprBuilder::MakeLiteral(true);
282   auto literal_false = TreeExprBuilder::MakeLiteral(false);
283   auto literal_null = TreeExprBuilder::MakeNull(boolean());
284 
285   auto a_gt_b =
286       TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
287   auto a_lt_b = TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
288   auto cond_else =
289       TreeExprBuilder::MakeIf(a_lt_b, literal_false, literal_null, boolean());
290   auto cond_if = TreeExprBuilder::MakeIf(a_gt_b, literal_true, cond_else, boolean());
291   auto not_cond_if = TreeExprBuilder::MakeFunction("not", {cond_if}, boolean());
292 
293   auto outer_else = TreeExprBuilder::MakeIf(not_cond_if, literal_2, literal_3, int32());
294   auto outer_if = TreeExprBuilder::MakeIf(cond_if, literal_1, outer_else, int32());
295   auto expr = TreeExprBuilder::MakeExpression(outer_if, field_result);
296 
297   // Build a projector for the expressions.
298   std::shared_ptr<Projector> projector;
299   auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
300   EXPECT_TRUE(status.ok());
301 
302   // Create a row-batch with some sample data
303   int num_records = 6;
304   auto array_a =
305       MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true});
306   auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
307                                      {true, true, true, false, false, false});
308   // expected output
309   auto exp =
310       MakeArrowArrayInt32({1, 2, 2, 3, 3, 3}, {true, true, true, true, true, true});
311 
312   // prepare input record batch
313   auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
314 
315   // Evaluate expression
316   arrow::ArrayVector outputs;
317   status = projector->Evaluate(*in_batch, pool_, &outputs);
318   EXPECT_TRUE(status.ok());
319 
320   // Validate results
321   EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
322 }
323 
// Projects a 19-level else-if ladder over a single int32 column; each level
// compares `a` against a multiple of 10 and yields that threshold.
TEST_F(TestIfExpr, TestBigNested) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto schema = arrow::schema({fielda});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a < 10)
  //   10
  // else if (a < 20)
  //   20
  // ..
  // ..
  // else if (a < 190)
  //   190
  // else
  //   200
  auto node_a = TreeExprBuilder::MakeField(fielda);
  // Start from the innermost else (200) and wrap it outwards, so the last
  // iteration (thresh == 10) produces the outermost if.
  auto top_node = TreeExprBuilder::MakeLiteral(200);
  for (int thresh = 190; thresh > 0; thresh -= 10) {
    auto literal = TreeExprBuilder::MakeLiteral(thresh);
    auto condition =
        TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean());
    top_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32());
  }
  auto expr = TreeExprBuilder::MakeExpression(top_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Must be fatal: `projector` is dereferenced below and may still be null if
  // the build failed.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Create a row-batch with some sample data
  int num_records = 4;
  auto array0 = MakeArrowArrayInt32({10, 102, 158, 302}, {true, true, true, true});

  // expected output: the first threshold strictly greater than each input, or
  // 200 once the input is past every threshold.
  auto exp = MakeArrowArrayInt32({20, 110, 160, 200}, {true, true, true, true});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Must be fatal: outputs.at(0) below would throw if `outputs` were left empty.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
377 
378 }  // namespace gandiva
379