1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include <sstream>
19 
20 #include <gtest/gtest.h>
21 #include "arrow/memory_pool.h"
22 #include "arrow/status.h"
23 
24 #include "gandiva/decimal_scalar.h"
25 #include "gandiva/decimal_type_util.h"
26 #include "gandiva/projector.h"
27 #include "gandiva/tests/test_util.h"
28 #include "gandiva/tree_expr_builder.h"
29 
30 using arrow::Decimal128;
31 
32 namespace gandiva {
33 
// Expects `expected == actual` for two DecimalScalar128 values. On failure it
// reports the operation name `op`, both operands and both results via ToString().
// There is deliberately no trailing semicolon: the call site supplies it, so the
// macro can also be used as the body of an unbraced if/else.
#define EXPECT_DECIMAL_RESULT(op, x, y, expected, actual)                        \
  EXPECT_EQ(expected, actual) << (op) << " (" << (x).ToString() << "),("     \
                              << (y).ToString() << ")"                       \
                              << " expected : " << (expected).ToString()     \
                              << " actual : " << (actual).ToString()
39 
decimal_literal(const char * value,int precision,int scale)40 DecimalScalar128 decimal_literal(const char* value, int precision, int scale) {
41   std::string value_string = std::string(value);
42   return DecimalScalar128(value_string, precision, scale);
43 }
44 
45 class TestDecimalOps : public ::testing::Test {
46  public:
SetUp()47   void SetUp() { pool_ = arrow::default_memory_pool(); }
48 
49   ArrayPtr MakeDecimalVector(const DecimalScalar128& in);
50 
51   void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x,
52               const DecimalScalar128& y, const DecimalScalar128& expected);
53 
AddAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)54   void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
55                     const DecimalScalar128& expected) {
56     Verify(DecimalTypeUtil::kOpAdd, "add", x, y, expected);
57   }
58 
SubtractAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)59   void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
60                          const DecimalScalar128& expected) {
61     Verify(DecimalTypeUtil::kOpSubtract, "subtract", x, y, expected);
62   }
63 
MultiplyAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)64   void MultiplyAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
65                          const DecimalScalar128& expected) {
66     Verify(DecimalTypeUtil::kOpMultiply, "multiply", x, y, expected);
67   }
68 
DivideAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)69   void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
70                        const DecimalScalar128& expected) {
71     Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected);
72   }
73 
ModAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)74   void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
75                     const DecimalScalar128& expected) {
76     Verify(DecimalTypeUtil::kOpMod, "mod", x, y, expected);
77   }
78 
79  protected:
80   arrow::MemoryPool* pool_;
81 };
82 
MakeDecimalVector(const DecimalScalar128 & in)83 ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) {
84   std::vector<arrow::Decimal128> ret;
85 
86   Decimal128 decimal_value = in.value();
87 
88   auto decimal_type = std::make_shared<arrow::Decimal128Type>(in.precision(), in.scale());
89   return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true});
90 }
91 
Verify(DecimalTypeUtil::Op op,const std::string & function,const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)92 void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function,
93                             const DecimalScalar128& x, const DecimalScalar128& y,
94                             const DecimalScalar128& expected) {
95   auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
96   auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
97   auto field_x = field("x", x_type);
98   auto field_y = field("y", y_type);
99   auto schema = arrow::schema({field_x, field_y});
100 
101   Decimal128TypePtr output_type;
102   auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type);
103   ARROW_EXPECT_OK(status);
104 
105   // output fields
106   auto res = field("res", output_type);
107 
108   // build expression : x op y
109   auto expr = TreeExprBuilder::MakeExpression(function, {field_x, field_y}, res);
110 
111   // Build a projector for the expression.
112   std::shared_ptr<Projector> projector;
113   status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
114   ARROW_EXPECT_OK(status);
115 
116   // Create a row-batch with some sample data
117   auto array_a = MakeDecimalVector(x);
118   auto array_b = MakeDecimalVector(y);
119 
120   // prepare input record batch
121   auto in_batch = arrow::RecordBatch::Make(schema, 1 /*num_records*/, {array_a, array_b});
122 
123   // Evaluate expression
124   arrow::ArrayVector outputs;
125   status = projector->Evaluate(*in_batch, pool_, &outputs);
126   ARROW_EXPECT_OK(status);
127 
128   // Validate results
129   auto out_array = dynamic_cast<arrow::Decimal128Array*>(outputs[0].get());
130   const Decimal128 out_value(out_array->GetValue(0));
131 
132   auto dtype = dynamic_cast<arrow::Decimal128Type*>(out_array->type().get());
133   std::string value_string = out_value.ToString(0);
134   DecimalScalar128 actual{value_string, dtype->precision(), dtype->scale()};
135 
136   EXPECT_DECIMAL_RESULT(function, x, y, expected, actual);
137 }
138 
TEST_F(TestDecimalOps,TestAdd)139 TEST_F(TestDecimalOps, TestAdd) {
140   // fast-path
141   AddAndVerify(decimal_literal("201", 30, 3),   // x
142                decimal_literal("301", 30, 3),   // y
143                decimal_literal("502", 31, 3));  // expected
144 
145   AddAndVerify(decimal_literal("201", 30, 3),    // x
146                decimal_literal("301", 30, 2),    // y
147                decimal_literal("3211", 32, 3));  // expected
148 
149   AddAndVerify(decimal_literal("201", 30, 3),    // x
150                decimal_literal("301", 30, 4),    // y
151                decimal_literal("2311", 32, 4));  // expected
152 
153   // max precision, but no overflow
154   AddAndVerify(decimal_literal("201", 38, 3),   // x
155                decimal_literal("301", 38, 3),   // y
156                decimal_literal("502", 38, 3));  // expected
157 
158   AddAndVerify(decimal_literal("201", 38, 3),    // x
159                decimal_literal("301", 38, 2),    // y
160                decimal_literal("3211", 38, 3));  // expected
161 
162   AddAndVerify(decimal_literal("201", 38, 3),    // x
163                decimal_literal("301", 38, 4),    // y
164                decimal_literal("2311", 38, 4));  // expected
165 
166   AddAndVerify(decimal_literal("201", 38, 3),      // x
167                decimal_literal("301", 38, 7),      // y
168                decimal_literal("201030", 38, 6));  // expected
169 
170   AddAndVerify(decimal_literal("1201", 38, 3),   // x
171                decimal_literal("1801", 38, 3),   // y
172                decimal_literal("3002", 38, 3));  // carry-over from fractional
173 
174   // max precision
175   AddAndVerify(decimal_literal("09999999999999999999999999999999000000", 38, 5),  // x
176                decimal_literal("100", 38, 7),                                     // y
177                decimal_literal("99999999999999999999999999999990000010", 38, 6));
178 
179   AddAndVerify(decimal_literal("-09999999999999999999999999999999000000", 38, 5),  // x
180                decimal_literal("100", 38, 7),                                      // y
181                decimal_literal("-99999999999999999999999999999989999990", 38, 6));
182 
183   AddAndVerify(decimal_literal("09999999999999999999999999999999000000", 38, 5),  // x
184                decimal_literal("-100", 38, 7),                                    // y
185                decimal_literal("99999999999999999999999999999989999990", 38, 6));
186 
187   AddAndVerify(decimal_literal("-09999999999999999999999999999999000000", 38, 5),  // x
188                decimal_literal("-100", 38, 7),                                     // y
189                decimal_literal("-99999999999999999999999999999990000010", 38, 6));
190 
191   AddAndVerify(decimal_literal("09999999999999999999999999999999999999", 38, 6),  // x
192                decimal_literal("89999999999999999999999999999999999999", 38, 7),  // y
193                decimal_literal("18999999999999999999999999999999999999", 38, 6));
194 
195   // Both -ve
196   AddAndVerify(decimal_literal("-201", 30, 3),    // x
197                decimal_literal("-301", 30, 2),    // y
198                decimal_literal("-3211", 32, 3));  // expected
199 
200   AddAndVerify(decimal_literal("-201", 38, 3),    // x
201                decimal_literal("-301", 38, 4),    // y
202                decimal_literal("-2311", 38, 4));  // expected
203 
204   // Mix of +ve and -ve
205   AddAndVerify(decimal_literal("-201", 30, 3),   // x
206                decimal_literal("301", 30, 2),    // y
207                decimal_literal("2809", 32, 3));  // expected
208 
209   AddAndVerify(decimal_literal("-201", 38, 3),    // x
210                decimal_literal("301", 38, 4),     // y
211                decimal_literal("-1709", 38, 4));  // expected
212 
213   AddAndVerify(decimal_literal("201", 38, 3),      // x
214                decimal_literal("-301", 38, 7),     // y
215                decimal_literal("200970", 38, 6));  // expected
216 
217   AddAndVerify(decimal_literal("-1901", 38, 4),  // x
218                decimal_literal("1801", 38, 4),   // y
219                decimal_literal("-100", 38, 4));  // expected
220 
221   AddAndVerify(decimal_literal("1801", 38, 4),   // x
222                decimal_literal("-1901", 38, 4),  // y
223                decimal_literal("-100", 38, 4));  // expected
224 
225   // rounding +ve
226   AddAndVerify(decimal_literal("1000999", 38, 6),   // x
227                decimal_literal("10000999", 38, 7),  // y
228                decimal_literal("2001099", 38, 6));
229 
230   AddAndVerify(decimal_literal("1000999", 38, 6),   // x
231                decimal_literal("10000995", 38, 7),  // y
232                decimal_literal("2001099", 38, 6));
233 
234   AddAndVerify(decimal_literal("1000999", 38, 6),   // x
235                decimal_literal("10000992", 38, 7),  // y
236                decimal_literal("2001098", 38, 6));
237 
238   // rounding -ve
239   AddAndVerify(decimal_literal("-1000999", 38, 6),   // x
240                decimal_literal("-10000999", 38, 7),  // y
241                decimal_literal("-2001099", 38, 6));
242 
243   AddAndVerify(decimal_literal("-1000999", 38, 6),   // x
244                decimal_literal("-10000995", 38, 7),  // y
245                decimal_literal("-2001099", 38, 6));
246 
247   AddAndVerify(decimal_literal("-1000999", 38, 6),   // x
248                decimal_literal("-10000992", 38, 7),  // y
249                decimal_literal("-2001098", 38, 6));
250 }
251 
// Subtract is a wrapper over add, so minimal tests are sufficient.
// Table-driven coverage of decimal "subtract". Each row is {x, y, expected}, and
// every operand is {unscaled digits, precision, scale}.
TEST_F(TestDecimalOps, TestSubtract) {
  struct Operand {
    const char* digits;
    int precision;
    int scale;
  };
  struct SubtractCase {
    Operand x;
    Operand y;
    Operand expected;
  };
  const auto to_scalar = [](const Operand& op) {
    return decimal_literal(op.digits, op.precision, op.scale);
  };

  const SubtractCase cases[] = {
      // fast-path
      {{"201", 30, 3}, {"301", 30, 3}, {"-100", 31, 3}},

      // max precision
      {{"09999999999999999999999999999999000000", 38, 5},
       {"100", 38, 7},
       {"99999999999999999999999999999989999990", 38, 6}},

      // Mix of +ve and -ve
      {{"-201", 30, 3}, {"301", 30, 2}, {"-3211", 32, 3}},
  };

  for (const SubtractCase& c : cases) {
    SubtractAndVerify(to_scalar(c.x), to_scalar(c.y), to_scalar(c.expected));
  }
}
270 
// Multiply, divide and mod are covered extensively by the unit tests in
// decimal_ops_test.cc, so these end-to-end tests are kept basic.
// End-to-end checks of decimal "multiply" through the projector.
TEST_F(TestDecimalOps, TestMultiply) {
  // fast-path: 0.201 * 3.01 = 0.60501
  MultiplyAndVerify(decimal_literal("201", 10, 3),     // x
                    decimal_literal("301", 10, 2),     // y
                    decimal_literal("60501", 21, 5));  // expected

  // max precision: the expected result has precision 38 and scale 6
  MultiplyAndVerify(DecimalScalar128(std::string(35, '9'), 38, 20),  // x
                    DecimalScalar128(std::string(36, '9'), 38, 20),  // y
                    // expected
                    DecimalScalar128("9999999999999999999999999999999999890", 38, 6));
}
284 
// End-to-end checks of decimal "divide" through the projector.
TEST_F(TestDecimalOps, TestDivide) {
  // 0.201 / 3.01 = 0.06677740863787 at scale 14
  DivideAndVerify(decimal_literal("201", 10, 3),              // x
                  decimal_literal("301", 10, 2),              // y
                  decimal_literal("6677740863787", 23, 14));  // expected

  // max precision: the expected result has precision 38 and scale 6
  DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20),  // x
                  DecimalScalar128(std::string(35, '9'), 38, 20),  // y
                  DecimalScalar128("1000000000", 38, 6));          // expected
}
294 
// End-to-end checks of decimal "mod" through the projector.
TEST_F(TestDecimalOps, TestMod) {
  // 2.01 mod 0.301 = 2.01 - 6 * 0.301 = 0.204
  ModAndVerify(decimal_literal("201", 20, 2),   // x
               decimal_literal("301", 20, 3),   // y
               decimal_literal("204", 20, 3));  // expected

  // max precision operands with differing scales
  ModAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20),  // x
               DecimalScalar128(std::string(35, '9'), 38, 21),  // y
               DecimalScalar128("9990", 38, 21));               // expected
}
304 
305 }  // namespace gandiva
306