1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <sstream>
19
20 #include <gtest/gtest.h>
21 #include "arrow/memory_pool.h"
22 #include "arrow/status.h"
23
24 #include "gandiva/decimal_scalar.h"
25 #include "gandiva/decimal_type_util.h"
26 #include "gandiva/projector.h"
27 #include "gandiva/tests/test_util.h"
28 #include "gandiva/tree_expr_builder.h"
29
30 using arrow::Decimal128;
31
32 namespace gandiva {
33
// EXPECTs that `expected == actual` for two DecimalScalar128 values; on failure
// the message names the function `op`, both operands and the expected/actual
// values (all rendered with ToString()).
// The macro deliberately has no trailing semicolon: the caller supplies it, so
// the macro behaves like a single statement (e.g. as the body of an unbraced
// if/else, where an extra empty statement would orphan the `else`).
#define EXPECT_DECIMAL_RESULT(op, x, y, expected, actual)                  \
  EXPECT_EQ(expected, actual) << (op) << " (" << (x).ToString() << "),("   \
                              << (y).ToString() << ")"                     \
                              << " expected : " << (expected).ToString()   \
                              << " actual : " << (actual).ToString()
39
decimal_literal(const char * value,int precision,int scale)40 DecimalScalar128 decimal_literal(const char* value, int precision, int scale) {
41 std::string value_string = std::string(value);
42 return DecimalScalar128(value_string, precision, scale);
43 }
44
45 class TestDecimalOps : public ::testing::Test {
46 public:
SetUp()47 void SetUp() { pool_ = arrow::default_memory_pool(); }
48
49 ArrayPtr MakeDecimalVector(const DecimalScalar128& in);
50
51 void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x,
52 const DecimalScalar128& y, const DecimalScalar128& expected);
53
AddAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)54 void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
55 const DecimalScalar128& expected) {
56 Verify(DecimalTypeUtil::kOpAdd, "add", x, y, expected);
57 }
58
SubtractAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)59 void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
60 const DecimalScalar128& expected) {
61 Verify(DecimalTypeUtil::kOpSubtract, "subtract", x, y, expected);
62 }
63
MultiplyAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)64 void MultiplyAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
65 const DecimalScalar128& expected) {
66 Verify(DecimalTypeUtil::kOpMultiply, "multiply", x, y, expected);
67 }
68
DivideAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)69 void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
70 const DecimalScalar128& expected) {
71 Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected);
72 }
73
ModAndVerify(const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)74 void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
75 const DecimalScalar128& expected) {
76 Verify(DecimalTypeUtil::kOpMod, "mod", x, y, expected);
77 }
78
79 protected:
80 arrow::MemoryPool* pool_;
81 };
82
MakeDecimalVector(const DecimalScalar128 & in)83 ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) {
84 std::vector<arrow::Decimal128> ret;
85
86 Decimal128 decimal_value = in.value();
87
88 auto decimal_type = std::make_shared<arrow::Decimal128Type>(in.precision(), in.scale());
89 return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true});
90 }
91
Verify(DecimalTypeUtil::Op op,const std::string & function,const DecimalScalar128 & x,const DecimalScalar128 & y,const DecimalScalar128 & expected)92 void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function,
93 const DecimalScalar128& x, const DecimalScalar128& y,
94 const DecimalScalar128& expected) {
95 auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
96 auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
97 auto field_x = field("x", x_type);
98 auto field_y = field("y", y_type);
99 auto schema = arrow::schema({field_x, field_y});
100
101 Decimal128TypePtr output_type;
102 auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type);
103 ARROW_EXPECT_OK(status);
104
105 // output fields
106 auto res = field("res", output_type);
107
108 // build expression : x op y
109 auto expr = TreeExprBuilder::MakeExpression(function, {field_x, field_y}, res);
110
111 // Build a projector for the expression.
112 std::shared_ptr<Projector> projector;
113 status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
114 ARROW_EXPECT_OK(status);
115
116 // Create a row-batch with some sample data
117 auto array_a = MakeDecimalVector(x);
118 auto array_b = MakeDecimalVector(y);
119
120 // prepare input record batch
121 auto in_batch = arrow::RecordBatch::Make(schema, 1 /*num_records*/, {array_a, array_b});
122
123 // Evaluate expression
124 arrow::ArrayVector outputs;
125 status = projector->Evaluate(*in_batch, pool_, &outputs);
126 ARROW_EXPECT_OK(status);
127
128 // Validate results
129 auto out_array = dynamic_cast<arrow::Decimal128Array*>(outputs[0].get());
130 const Decimal128 out_value(out_array->GetValue(0));
131
132 auto dtype = dynamic_cast<arrow::Decimal128Type*>(out_array->type().get());
133 std::string value_string = out_value.ToString(0);
134 DecimalScalar128 actual{value_string, dtype->precision(), dtype->scale()};
135
136 EXPECT_DECIMAL_RESULT(function, x, y, expected, actual);
137 }
138
TEST_F(TestDecimalOps, TestAdd) {
  // Table of {x, y, expected}; every case checks add(x, y) == expected.
  // Literals are decimal_literal(unscaled digits, precision, scale).
  struct AddCase {
    DecimalScalar128 x;
    DecimalScalar128 y;
    DecimalScalar128 expected;
  };

  const AddCase kCases[] = {
      // fast-path
      {decimal_literal("201", 30, 3), decimal_literal("301", 30, 3),
       decimal_literal("502", 31, 3)},
      {decimal_literal("201", 30, 3), decimal_literal("301", 30, 2),
       decimal_literal("3211", 32, 3)},
      {decimal_literal("201", 30, 3), decimal_literal("301", 30, 4),
       decimal_literal("2311", 32, 4)},

      // max precision, but no overflow
      {decimal_literal("201", 38, 3), decimal_literal("301", 38, 3),
       decimal_literal("502", 38, 3)},
      {decimal_literal("201", 38, 3), decimal_literal("301", 38, 2),
       decimal_literal("3211", 38, 3)},
      {decimal_literal("201", 38, 3), decimal_literal("301", 38, 4),
       decimal_literal("2311", 38, 4)},
      {decimal_literal("201", 38, 3), decimal_literal("301", 38, 7),
       decimal_literal("201030", 38, 6)},
      // carry-over from fractional
      {decimal_literal("1201", 38, 3), decimal_literal("1801", 38, 3),
       decimal_literal("3002", 38, 3)},

      // max precision
      {decimal_literal("09999999999999999999999999999999000000", 38, 5),
       decimal_literal("100", 38, 7),
       decimal_literal("99999999999999999999999999999990000010", 38, 6)},
      {decimal_literal("-09999999999999999999999999999999000000", 38, 5),
       decimal_literal("100", 38, 7),
       decimal_literal("-99999999999999999999999999999989999990", 38, 6)},
      {decimal_literal("09999999999999999999999999999999000000", 38, 5),
       decimal_literal("-100", 38, 7),
       decimal_literal("99999999999999999999999999999989999990", 38, 6)},
      {decimal_literal("-09999999999999999999999999999999000000", 38, 5),
       decimal_literal("-100", 38, 7),
       decimal_literal("-99999999999999999999999999999990000010", 38, 6)},
      {decimal_literal("09999999999999999999999999999999999999", 38, 6),
       decimal_literal("89999999999999999999999999999999999999", 38, 7),
       decimal_literal("18999999999999999999999999999999999999", 38, 6)},

      // Both -ve
      {decimal_literal("-201", 30, 3), decimal_literal("-301", 30, 2),
       decimal_literal("-3211", 32, 3)},
      {decimal_literal("-201", 38, 3), decimal_literal("-301", 38, 4),
       decimal_literal("-2311", 38, 4)},

      // Mix of +ve and -ve
      {decimal_literal("-201", 30, 3), decimal_literal("301", 30, 2),
       decimal_literal("2809", 32, 3)},
      {decimal_literal("-201", 38, 3), decimal_literal("301", 38, 4),
       decimal_literal("-1709", 38, 4)},
      {decimal_literal("201", 38, 3), decimal_literal("-301", 38, 7),
       decimal_literal("200970", 38, 6)},
      {decimal_literal("-1901", 38, 4), decimal_literal("1801", 38, 4),
       decimal_literal("-100", 38, 4)},
      {decimal_literal("1801", 38, 4), decimal_literal("-1901", 38, 4),
       decimal_literal("-100", 38, 4)},

      // rounding +ve
      {decimal_literal("1000999", 38, 6), decimal_literal("10000999", 38, 7),
       decimal_literal("2001099", 38, 6)},
      {decimal_literal("1000999", 38, 6), decimal_literal("10000995", 38, 7),
       decimal_literal("2001099", 38, 6)},
      {decimal_literal("1000999", 38, 6), decimal_literal("10000992", 38, 7),
       decimal_literal("2001098", 38, 6)},

      // rounding -ve
      {decimal_literal("-1000999", 38, 6), decimal_literal("-10000999", 38, 7),
       decimal_literal("-2001099", 38, 6)},
      {decimal_literal("-1000999", 38, 6), decimal_literal("-10000995", 38, 7),
       decimal_literal("-2001099", 38, 6)},
      {decimal_literal("-1000999", 38, 6), decimal_literal("-10000992", 38, 7),
       decimal_literal("-2001098", 38, 6)},
  };

  for (const AddCase& c : kCases) {
    AddAndVerify(c.x, c.y, c.expected);
  }
}
251
// Subtract is a wrapper over add, so minimal tests are sufficient.
TEST_F(TestDecimalOps, TestSubtract) {
  // Table of {x, y, expected}; every case checks subtract(x, y) == expected.
  struct SubtractCase {
    DecimalScalar128 x;
    DecimalScalar128 y;
    DecimalScalar128 expected;
  };

  const SubtractCase kCases[] = {
      // fast-path
      {decimal_literal("201", 30, 3), decimal_literal("301", 30, 3),
       decimal_literal("-100", 31, 3)},

      // max precision
      {decimal_literal("09999999999999999999999999999999000000", 38, 5),
       decimal_literal("100", 38, 7),
       decimal_literal("99999999999999999999999999999989999990", 38, 6)},

      // Mix of +ve and -ve
      {decimal_literal("-201", 30, 3), decimal_literal("301", 30, 2),
       decimal_literal("-3211", 32, 3)},
  };

  for (const SubtractCase& c : kCases) {
    SubtractAndVerify(c.x, c.y, c.expected);
  }
}
270
// Multiply, divide and mod are covered extensively by the unit tests in
// decimal_ops_test.cc, so the tests here are kept basic.
TEST_F(TestDecimalOps, TestMultiply) {
  // fast-path: 0.201 * 3.01 = 0.60501
  MultiplyAndVerify(decimal_literal("201", 10, 3),     // x
                    decimal_literal("301", 10, 2),     // y
                    decimal_literal("60501", 21, 5));  // expected

  // max precision: x = 10^15 - 10^-20 (35 nines, scale 20) and
  // y = 10^16 - 10^-20 (36 nines, scale 20), so
  // x * y = 10^31 - 0.00011 + 10^-40, i.e. 10^37 - 110 unscaled at scale 6.
  MultiplyAndVerify(DecimalScalar128(std::string(35, '9'), 38, 20),  // x
                    DecimalScalar128(std::string(36, '9'), 38, 20),  // y
                    DecimalScalar128("9999999999999999999999999999999999890", 38,
                                     6));  // expected
}
284
TEST_F(TestDecimalOps, TestDivide) {
  // 0.201 / 3.01 = 0.066777408637873..., i.e. 0.06677740863787 at scale 14
  DivideAndVerify(decimal_literal("201", 10, 3),             // x
                  decimal_literal("301", 10, 2),             // y
                  decimal_literal("6677740863787", 23, 14));  // expected

  // x = 10^18 - 10^-20 (38 nines, scale 20), y = 10^15 - 10^-20 (35 nines,
  // scale 20); x / y = 1000 + ~10^-32, i.e. 1000.000000 at scale 6.
  DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20),  // x
                  DecimalScalar128(std::string(35, '9'), 38, 20),  // y
                  DecimalScalar128("1000000000", 38, 6));          // expected
}
294
TEST_F(TestDecimalOps, TestMod) {
  // 2.01 mod 0.301 = 2.01 - 6 * 0.301 = 0.204
  ModAndVerify(decimal_literal("201", 20, 2),   // x
               decimal_literal("301", 20, 3),   // y
               decimal_literal("204", 20, 3));  // expected

  // x = 10^18 - 10^-20 (38 nines, scale 20), y = 10^14 - 10^-21 (35 nines,
  // scale 21); x - 10^4 * y = 10^-17 - 10^-20 = 9990 * 10^-21.
  ModAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20),  // x
               DecimalScalar128(std::string(35, '9'), 38, 21),  // y
               DecimalScalar128("9990", 38, 21));               // expected
}
304
305 } // namespace gandiva
306