1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/compute/kernels/codegen_internal.h"
19
20 #include <cmath>
21 #include <functional>
22 #include <memory>
23 #include <mutex>
24 #include <vector>
25
26 #include "arrow/type_fwd.h"
27
28 namespace arrow {
29 namespace compute {
30 namespace internal {
31
ExecFail(KernelContext * ctx,const ExecBatch & batch,Datum * out)32 Status ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
33 return Status::NotImplemented("This kernel is malformed");
34 }
35
MakeFlippedBinaryExec(ArrayKernelExec exec)36 ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec) {
37 return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
38 ExecBatch flipped_batch = batch;
39 std::swap(flipped_batch.values[0], flipped_batch.values[1]);
40 return exec(ctx, flipped_batch, out);
41 };
42 }
43
ExampleParametricTypes()44 const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes() {
45 static DataTypeVector example_parametric_types = {
46 decimal128(12, 2),
47 duration(TimeUnit::SECOND),
48 timestamp(TimeUnit::SECOND),
49 time32(TimeUnit::SECOND),
50 time64(TimeUnit::MICRO),
51 fixed_size_binary(0),
52 list(null()),
53 large_list(null()),
54 fixed_size_list(field("dummy", null()), 0),
55 struct_({}),
56 sparse_union(FieldVector{}),
57 dense_union(FieldVector{}),
58 dictionary(int32(), null()),
59 map(null(), null())};
60 return example_parametric_types;
61 }
62
FirstType(KernelContext *,const std::vector<ValueDescr> & descrs)63 Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs) {
64 ValueDescr result = descrs.front();
65 result.shape = GetBroadcastShape(descrs);
66 return result;
67 }
68
LastType(KernelContext *,const std::vector<ValueDescr> & descrs)69 Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs) {
70 ValueDescr result = descrs.back();
71 result.shape = GetBroadcastShape(descrs);
72 return result;
73 }
74
ListValuesType(KernelContext *,const std::vector<ValueDescr> & args)75 Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
76 const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
77 return ValueDescr(list_type.value_type(), GetBroadcastShape(args));
78 }
79
EnsureDictionaryDecoded(std::vector<ValueDescr> * descrs)80 void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) {
81 EnsureDictionaryDecoded(descrs->data(), descrs->size());
82 }
83
EnsureDictionaryDecoded(ValueDescr * begin,size_t count)84 void EnsureDictionaryDecoded(ValueDescr* begin, size_t count) {
85 auto* end = begin + count;
86 for (auto it = begin; it != end; it++) {
87 if (it->type->id() == Type::DICTIONARY) {
88 it->type = checked_cast<const DictionaryType&>(*it->type).value_type();
89 }
90 }
91 }
92
ReplaceNullWithOtherType(std::vector<ValueDescr> * descrs)93 void ReplaceNullWithOtherType(std::vector<ValueDescr>* descrs) {
94 ReplaceNullWithOtherType(descrs->data(), descrs->size());
95 }
96
ReplaceNullWithOtherType(ValueDescr * first,size_t count)97 void ReplaceNullWithOtherType(ValueDescr* first, size_t count) {
98 DCHECK_EQ(count, 2);
99
100 ValueDescr* second = first++;
101 if (first->type->id() == Type::NA) {
102 first->type = second->type;
103 return;
104 }
105
106 if (second->type->id() == Type::NA) {
107 second->type = first->type;
108 return;
109 }
110 }
111
ReplaceTypes(const std::shared_ptr<DataType> & type,std::vector<ValueDescr> * descrs)112 void ReplaceTypes(const std::shared_ptr<DataType>& type,
113 std::vector<ValueDescr>* descrs) {
114 ReplaceTypes(type, descrs->data(), descrs->size());
115 }
116
ReplaceTypes(const std::shared_ptr<DataType> & type,ValueDescr * begin,size_t count)117 void ReplaceTypes(const std::shared_ptr<DataType>& type, ValueDescr* begin,
118 size_t count) {
119 auto* end = begin + count;
120 for (auto* it = begin; it != end; it++) {
121 it->type = type;
122 }
123 }
124
CommonNumeric(const std::vector<ValueDescr> & descrs)125 std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs) {
126 return CommonNumeric(descrs.data(), descrs.size());
127 }
128
CommonNumeric(const ValueDescr * begin,size_t count)129 std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
130 DCHECK_GT(count, 0) << "tried to find CommonNumeric type of an empty set";
131
132 for (size_t i = 0; i < count; i++) {
133 const auto& descr = *(begin + i);
134 auto id = descr.type->id();
135 if (!is_floating(id) && !is_integer(id)) {
136 // a common numeric type is only possible if all types are numeric
137 return nullptr;
138 }
139 if (id == Type::HALF_FLOAT) {
140 // float16 arithmetic is not currently supported
141 return nullptr;
142 }
143 }
144
145 for (size_t i = 0; i < count; i++) {
146 const auto& descr = *(begin + i);
147 if (descr.type->id() == Type::DOUBLE) return float64();
148 }
149
150 for (size_t i = 0; i < count; i++) {
151 const auto& descr = *(begin + i);
152 if (descr.type->id() == Type::FLOAT) return float32();
153 }
154
155 int max_width_signed = 0, max_width_unsigned = 0;
156
157 for (size_t i = 0; i < count; i++) {
158 const auto& descr = *(begin + i);
159 auto id = descr.type->id();
160 auto max_width = &(is_signed_integer(id) ? max_width_signed : max_width_unsigned);
161 *max_width = std::max(bit_width(id), *max_width);
162 }
163
164 if (max_width_signed == 0) {
165 if (max_width_unsigned >= 64) return uint64();
166 if (max_width_unsigned == 32) return uint32();
167 if (max_width_unsigned == 16) return uint16();
168 DCHECK_EQ(max_width_unsigned, 8);
169 return uint8();
170 }
171
172 if (max_width_signed <= max_width_unsigned) {
173 max_width_signed = static_cast<int>(BitUtil::NextPower2(max_width_unsigned + 1));
174 }
175
176 if (max_width_signed >= 64) return int64();
177 if (max_width_signed == 32) return int32();
178 if (max_width_signed == 16) return int16();
179 DCHECK_EQ(max_width_signed, 8);
180 return int8();
181 }
182
CommonTemporal(const ValueDescr * begin,size_t count)183 std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count) {
184 TimeUnit::type finest_unit = TimeUnit::SECOND;
185 const std::string* timezone = nullptr;
186 bool saw_date32 = false;
187 bool saw_date64 = false;
188
189 const ValueDescr* end = begin + count;
190 for (auto it = begin; it != end; it++) {
191 auto id = it->type->id();
192 // a common timestamp is only possible if all types are timestamp like
193 switch (id) {
194 case Type::DATE32:
195 // Date32's unit is days, but the coarsest we have is seconds
196 saw_date32 = true;
197 continue;
198 case Type::DATE64:
199 finest_unit = std::max(finest_unit, TimeUnit::MILLI);
200 saw_date64 = true;
201 continue;
202 case Type::TIMESTAMP: {
203 const auto& ty = checked_cast<const TimestampType&>(*it->type);
204 if (timezone && *timezone != ty.timezone()) return nullptr;
205 timezone = &ty.timezone();
206 finest_unit = std::max(finest_unit, ty.unit());
207 continue;
208 }
209 default:
210 return nullptr;
211 }
212 }
213
214 if (timezone) {
215 // At least one timestamp seen
216 return timestamp(finest_unit, *timezone);
217 } else if (saw_date64) {
218 return date64();
219 } else if (saw_date32) {
220 return date32();
221 }
222 return nullptr;
223 }
224
CommonBinary(const ValueDescr * begin,size_t count)225 std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count) {
226 bool all_utf8 = true, all_offset32 = true, all_fixed_width = true;
227
228 const ValueDescr* end = begin + count;
229 for (auto it = begin; it != end; ++it) {
230 auto id = it->type->id();
231 // a common varbinary type is only possible if all types are binary like
232 switch (id) {
233 case Type::STRING:
234 all_fixed_width = false;
235 continue;
236 case Type::BINARY:
237 all_fixed_width = false;
238 all_utf8 = false;
239 continue;
240 case Type::FIXED_SIZE_BINARY:
241 all_utf8 = false;
242 continue;
243 case Type::LARGE_STRING:
244 all_offset32 = false;
245 all_fixed_width = false;
246 continue;
247 case Type::LARGE_BINARY:
248 all_offset32 = false;
249 all_fixed_width = false;
250 all_utf8 = false;
251 continue;
252 default:
253 return nullptr;
254 }
255 }
256
257 if (all_fixed_width) {
258 // At least for the purposes of comparison, no need to cast.
259 return nullptr;
260 }
261
262 if (all_utf8) {
263 if (all_offset32) return utf8();
264 return large_utf8();
265 }
266
267 if (all_offset32) return binary();
268 return large_binary();
269 }
270
CastBinaryDecimalArgs(DecimalPromotion promotion,std::vector<ValueDescr> * descrs)271 Status CastBinaryDecimalArgs(DecimalPromotion promotion,
272 std::vector<ValueDescr>* descrs) {
273 auto& left_type = (*descrs)[0].type;
274 auto& right_type = (*descrs)[1].type;
275 DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id()));
276
277 // decimal + float = float
278 if (is_floating(left_type->id())) {
279 right_type = left_type;
280 return Status::OK();
281 } else if (is_floating(right_type->id())) {
282 left_type = right_type;
283 return Status::OK();
284 }
285
286 // precision, scale of left and right args
287 int32_t p1, s1, p2, s2;
288
289 // decimal + integer = decimal
290 if (is_decimal(left_type->id())) {
291 auto decimal = checked_cast<const DecimalType*>(left_type.get());
292 p1 = decimal->precision();
293 s1 = decimal->scale();
294 } else {
295 DCHECK(is_integer(left_type->id()));
296 ARROW_ASSIGN_OR_RAISE(p1, MaxDecimalDigitsForInteger(left_type->id()));
297 s1 = 0;
298 }
299 if (is_decimal(right_type->id())) {
300 auto decimal = checked_cast<const DecimalType*>(right_type.get());
301 p2 = decimal->precision();
302 s2 = decimal->scale();
303 } else {
304 DCHECK(is_integer(right_type->id()));
305 ARROW_ASSIGN_OR_RAISE(p2, MaxDecimalDigitsForInteger(right_type->id()));
306 s2 = 0;
307 }
308 if (s1 < 0 || s2 < 0) {
309 return Status::NotImplemented("Decimals with negative scales not supported");
310 }
311
312 // decimal128 + decimal256 = decimal256
313 Type::type casted_type_id = Type::DECIMAL128;
314 if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) {
315 casted_type_id = Type::DECIMAL256;
316 }
317
318 // decimal promotion rules compatible with amazon redshift
319 // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
320 int32_t left_scaleup = 0;
321 int32_t right_scaleup = 0;
322
323 switch (promotion) {
324 case DecimalPromotion::kAdd: {
325 left_scaleup = std::max(s1, s2) - s1;
326 right_scaleup = std::max(s1, s2) - s2;
327 break;
328 }
329 case DecimalPromotion::kMultiply: {
330 left_scaleup = right_scaleup = 0;
331 break;
332 }
333 case DecimalPromotion::kDivide: {
334 left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
335 right_scaleup = 0;
336 break;
337 }
338 default:
339 DCHECK(false) << "Invalid DecimalPromotion value " << static_cast<int>(promotion);
340 }
341 ARROW_ASSIGN_OR_RAISE(
342 left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
343 ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
344 s2 + right_scaleup));
345 return Status::OK();
346 }
347
CastDecimalArgs(ValueDescr * begin,size_t count)348 Status CastDecimalArgs(ValueDescr* begin, size_t count) {
349 Type::type casted_type_id = Type::DECIMAL128;
350 auto* end = begin + count;
351
352 int32_t max_scale = 0;
353 bool any_floating = false;
354 for (auto* it = begin; it != end; ++it) {
355 const auto& ty = *it->type;
356 if (is_floating(ty.id())) {
357 // Decimal + float = float
358 any_floating = true;
359 } else if (is_integer(ty.id())) {
360 // Nothing to do here
361 } else if (is_decimal(ty.id())) {
362 max_scale = std::max(max_scale, checked_cast<const DecimalType&>(ty).scale());
363 if (ty.id() == Type::DECIMAL256) {
364 casted_type_id = Type::DECIMAL256;
365 }
366 } else {
367 // Non-numeric, can't cast
368 return Status::OK();
369 }
370 }
371 if (any_floating) {
372 ReplaceTypes(float64(), begin, count);
373 return Status::OK();
374 }
375
376 // All integer and decimal, rescale
377 int32_t common_precision = 0;
378 for (auto* it = begin; it != end; ++it) {
379 const auto& ty = *it->type;
380 if (is_integer(ty.id())) {
381 ARROW_ASSIGN_OR_RAISE(auto precision, MaxDecimalDigitsForInteger(ty.id()));
382 precision += max_scale;
383 common_precision = std::max(common_precision, precision);
384 } else if (is_decimal(ty.id())) {
385 const auto& decimal_ty = checked_cast<const DecimalType&>(ty);
386 auto precision = decimal_ty.precision();
387 const auto scale = decimal_ty.scale();
388 precision += max_scale - scale;
389 common_precision = std::max(common_precision, precision);
390 }
391 }
392
393 if (common_precision > BasicDecimal256::kMaxPrecision) {
394 return Status::Invalid("Result precision (", common_precision,
395 ") exceeds max precision of Decimal256 (",
396 BasicDecimal256::kMaxPrecision, ")");
397 } else if (common_precision > BasicDecimal128::kMaxPrecision) {
398 casted_type_id = Type::DECIMAL256;
399 }
400
401 for (auto* it = begin; it != end; ++it) {
402 ARROW_ASSIGN_OR_RAISE(it->type,
403 DecimalType::Make(casted_type_id, common_precision, max_scale));
404 }
405
406 return Status::OK();
407 }
408
HasDecimal(const std::vector<ValueDescr> & descrs)409 bool HasDecimal(const std::vector<ValueDescr>& descrs) {
410 for (const auto& descr : descrs) {
411 if (is_decimal(descr.type->id())) {
412 return true;
413 }
414 }
415 return false;
416 }
417
418 } // namespace internal
419 } // namespace compute
420 } // namespace arrow
421