1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "arrow/compute/kernels/codegen_internal.h"
19 
20 #include <cmath>
21 #include <functional>
22 #include <memory>
23 #include <mutex>
24 #include <vector>
25 
26 #include "arrow/type_fwd.h"
27 
28 namespace arrow {
29 namespace compute {
30 namespace internal {
31 
ExecFail(KernelContext * ctx,const ExecBatch & batch,Datum * out)32 Status ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
33   return Status::NotImplemented("This kernel is malformed");
34 }
35 
MakeFlippedBinaryExec(ArrayKernelExec exec)36 ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec) {
37   return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
38     ExecBatch flipped_batch = batch;
39     std::swap(flipped_batch.values[0], flipped_batch.values[1]);
40     return exec(ctx, flipped_batch, out);
41   };
42 }
43 
ExampleParametricTypes()44 const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes() {
45   static DataTypeVector example_parametric_types = {
46       decimal128(12, 2),
47       duration(TimeUnit::SECOND),
48       timestamp(TimeUnit::SECOND),
49       time32(TimeUnit::SECOND),
50       time64(TimeUnit::MICRO),
51       fixed_size_binary(0),
52       list(null()),
53       large_list(null()),
54       fixed_size_list(field("dummy", null()), 0),
55       struct_({}),
56       sparse_union(FieldVector{}),
57       dense_union(FieldVector{}),
58       dictionary(int32(), null()),
59       map(null(), null())};
60   return example_parametric_types;
61 }
62 
FirstType(KernelContext *,const std::vector<ValueDescr> & descrs)63 Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs) {
64   ValueDescr result = descrs.front();
65   result.shape = GetBroadcastShape(descrs);
66   return result;
67 }
68 
LastType(KernelContext *,const std::vector<ValueDescr> & descrs)69 Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs) {
70   ValueDescr result = descrs.back();
71   result.shape = GetBroadcastShape(descrs);
72   return result;
73 }
74 
ListValuesType(KernelContext *,const std::vector<ValueDescr> & args)75 Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
76   const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
77   return ValueDescr(list_type.value_type(), GetBroadcastShape(args));
78 }
79 
EnsureDictionaryDecoded(std::vector<ValueDescr> * descrs)80 void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) {
81   EnsureDictionaryDecoded(descrs->data(), descrs->size());
82 }
83 
EnsureDictionaryDecoded(ValueDescr * begin,size_t count)84 void EnsureDictionaryDecoded(ValueDescr* begin, size_t count) {
85   auto* end = begin + count;
86   for (auto it = begin; it != end; it++) {
87     if (it->type->id() == Type::DICTIONARY) {
88       it->type = checked_cast<const DictionaryType&>(*it->type).value_type();
89     }
90   }
91 }
92 
ReplaceNullWithOtherType(std::vector<ValueDescr> * descrs)93 void ReplaceNullWithOtherType(std::vector<ValueDescr>* descrs) {
94   ReplaceNullWithOtherType(descrs->data(), descrs->size());
95 }
96 
ReplaceNullWithOtherType(ValueDescr * first,size_t count)97 void ReplaceNullWithOtherType(ValueDescr* first, size_t count) {
98   DCHECK_EQ(count, 2);
99 
100   ValueDescr* second = first++;
101   if (first->type->id() == Type::NA) {
102     first->type = second->type;
103     return;
104   }
105 
106   if (second->type->id() == Type::NA) {
107     second->type = first->type;
108     return;
109   }
110 }
111 
ReplaceTypes(const std::shared_ptr<DataType> & type,std::vector<ValueDescr> * descrs)112 void ReplaceTypes(const std::shared_ptr<DataType>& type,
113                   std::vector<ValueDescr>* descrs) {
114   ReplaceTypes(type, descrs->data(), descrs->size());
115 }
116 
ReplaceTypes(const std::shared_ptr<DataType> & type,ValueDescr * begin,size_t count)117 void ReplaceTypes(const std::shared_ptr<DataType>& type, ValueDescr* begin,
118                   size_t count) {
119   auto* end = begin + count;
120   for (auto* it = begin; it != end; it++) {
121     it->type = type;
122   }
123 }
124 
CommonNumeric(const std::vector<ValueDescr> & descrs)125 std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs) {
126   return CommonNumeric(descrs.data(), descrs.size());
127 }
128 
CommonNumeric(const ValueDescr * begin,size_t count)129 std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
130   DCHECK_GT(count, 0) << "tried to find CommonNumeric type of an empty set";
131 
132   for (size_t i = 0; i < count; i++) {
133     const auto& descr = *(begin + i);
134     auto id = descr.type->id();
135     if (!is_floating(id) && !is_integer(id)) {
136       // a common numeric type is only possible if all types are numeric
137       return nullptr;
138     }
139     if (id == Type::HALF_FLOAT) {
140       // float16 arithmetic is not currently supported
141       return nullptr;
142     }
143   }
144 
145   for (size_t i = 0; i < count; i++) {
146     const auto& descr = *(begin + i);
147     if (descr.type->id() == Type::DOUBLE) return float64();
148   }
149 
150   for (size_t i = 0; i < count; i++) {
151     const auto& descr = *(begin + i);
152     if (descr.type->id() == Type::FLOAT) return float32();
153   }
154 
155   int max_width_signed = 0, max_width_unsigned = 0;
156 
157   for (size_t i = 0; i < count; i++) {
158     const auto& descr = *(begin + i);
159     auto id = descr.type->id();
160     auto max_width = &(is_signed_integer(id) ? max_width_signed : max_width_unsigned);
161     *max_width = std::max(bit_width(id), *max_width);
162   }
163 
164   if (max_width_signed == 0) {
165     if (max_width_unsigned >= 64) return uint64();
166     if (max_width_unsigned == 32) return uint32();
167     if (max_width_unsigned == 16) return uint16();
168     DCHECK_EQ(max_width_unsigned, 8);
169     return uint8();
170   }
171 
172   if (max_width_signed <= max_width_unsigned) {
173     max_width_signed = static_cast<int>(BitUtil::NextPower2(max_width_unsigned + 1));
174   }
175 
176   if (max_width_signed >= 64) return int64();
177   if (max_width_signed == 32) return int32();
178   if (max_width_signed == 16) return int16();
179   DCHECK_EQ(max_width_signed, 8);
180   return int8();
181 }
182 
CommonTemporal(const ValueDescr * begin,size_t count)183 std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count) {
184   TimeUnit::type finest_unit = TimeUnit::SECOND;
185   const std::string* timezone = nullptr;
186   bool saw_date32 = false;
187   bool saw_date64 = false;
188 
189   const ValueDescr* end = begin + count;
190   for (auto it = begin; it != end; it++) {
191     auto id = it->type->id();
192     // a common timestamp is only possible if all types are timestamp like
193     switch (id) {
194       case Type::DATE32:
195         // Date32's unit is days, but the coarsest we have is seconds
196         saw_date32 = true;
197         continue;
198       case Type::DATE64:
199         finest_unit = std::max(finest_unit, TimeUnit::MILLI);
200         saw_date64 = true;
201         continue;
202       case Type::TIMESTAMP: {
203         const auto& ty = checked_cast<const TimestampType&>(*it->type);
204         if (timezone && *timezone != ty.timezone()) return nullptr;
205         timezone = &ty.timezone();
206         finest_unit = std::max(finest_unit, ty.unit());
207         continue;
208       }
209       default:
210         return nullptr;
211     }
212   }
213 
214   if (timezone) {
215     // At least one timestamp seen
216     return timestamp(finest_unit, *timezone);
217   } else if (saw_date64) {
218     return date64();
219   } else if (saw_date32) {
220     return date32();
221   }
222   return nullptr;
223 }
224 
CommonBinary(const ValueDescr * begin,size_t count)225 std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count) {
226   bool all_utf8 = true, all_offset32 = true, all_fixed_width = true;
227 
228   const ValueDescr* end = begin + count;
229   for (auto it = begin; it != end; ++it) {
230     auto id = it->type->id();
231     // a common varbinary type is only possible if all types are binary like
232     switch (id) {
233       case Type::STRING:
234         all_fixed_width = false;
235         continue;
236       case Type::BINARY:
237         all_fixed_width = false;
238         all_utf8 = false;
239         continue;
240       case Type::FIXED_SIZE_BINARY:
241         all_utf8 = false;
242         continue;
243       case Type::LARGE_STRING:
244         all_offset32 = false;
245         all_fixed_width = false;
246         continue;
247       case Type::LARGE_BINARY:
248         all_offset32 = false;
249         all_fixed_width = false;
250         all_utf8 = false;
251         continue;
252       default:
253         return nullptr;
254     }
255   }
256 
257   if (all_fixed_width) {
258     // At least for the purposes of comparison, no need to cast.
259     return nullptr;
260   }
261 
262   if (all_utf8) {
263     if (all_offset32) return utf8();
264     return large_utf8();
265   }
266 
267   if (all_offset32) return binary();
268   return large_binary();
269 }
270 
CastBinaryDecimalArgs(DecimalPromotion promotion,std::vector<ValueDescr> * descrs)271 Status CastBinaryDecimalArgs(DecimalPromotion promotion,
272                              std::vector<ValueDescr>* descrs) {
273   auto& left_type = (*descrs)[0].type;
274   auto& right_type = (*descrs)[1].type;
275   DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id()));
276 
277   // decimal + float = float
278   if (is_floating(left_type->id())) {
279     right_type = left_type;
280     return Status::OK();
281   } else if (is_floating(right_type->id())) {
282     left_type = right_type;
283     return Status::OK();
284   }
285 
286   // precision, scale of left and right args
287   int32_t p1, s1, p2, s2;
288 
289   // decimal + integer = decimal
290   if (is_decimal(left_type->id())) {
291     auto decimal = checked_cast<const DecimalType*>(left_type.get());
292     p1 = decimal->precision();
293     s1 = decimal->scale();
294   } else {
295     DCHECK(is_integer(left_type->id()));
296     ARROW_ASSIGN_OR_RAISE(p1, MaxDecimalDigitsForInteger(left_type->id()));
297     s1 = 0;
298   }
299   if (is_decimal(right_type->id())) {
300     auto decimal = checked_cast<const DecimalType*>(right_type.get());
301     p2 = decimal->precision();
302     s2 = decimal->scale();
303   } else {
304     DCHECK(is_integer(right_type->id()));
305     ARROW_ASSIGN_OR_RAISE(p2, MaxDecimalDigitsForInteger(right_type->id()));
306     s2 = 0;
307   }
308   if (s1 < 0 || s2 < 0) {
309     return Status::NotImplemented("Decimals with negative scales not supported");
310   }
311 
312   // decimal128 + decimal256 = decimal256
313   Type::type casted_type_id = Type::DECIMAL128;
314   if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) {
315     casted_type_id = Type::DECIMAL256;
316   }
317 
318   // decimal promotion rules compatible with amazon redshift
319   // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
320   int32_t left_scaleup = 0;
321   int32_t right_scaleup = 0;
322 
323   switch (promotion) {
324     case DecimalPromotion::kAdd: {
325       left_scaleup = std::max(s1, s2) - s1;
326       right_scaleup = std::max(s1, s2) - s2;
327       break;
328     }
329     case DecimalPromotion::kMultiply: {
330       left_scaleup = right_scaleup = 0;
331       break;
332     }
333     case DecimalPromotion::kDivide: {
334       left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
335       right_scaleup = 0;
336       break;
337     }
338     default:
339       DCHECK(false) << "Invalid DecimalPromotion value " << static_cast<int>(promotion);
340   }
341   ARROW_ASSIGN_OR_RAISE(
342       left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
343   ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
344                                                       s2 + right_scaleup));
345   return Status::OK();
346 }
347 
CastDecimalArgs(ValueDescr * begin,size_t count)348 Status CastDecimalArgs(ValueDescr* begin, size_t count) {
349   Type::type casted_type_id = Type::DECIMAL128;
350   auto* end = begin + count;
351 
352   int32_t max_scale = 0;
353   bool any_floating = false;
354   for (auto* it = begin; it != end; ++it) {
355     const auto& ty = *it->type;
356     if (is_floating(ty.id())) {
357       // Decimal + float = float
358       any_floating = true;
359     } else if (is_integer(ty.id())) {
360       // Nothing to do here
361     } else if (is_decimal(ty.id())) {
362       max_scale = std::max(max_scale, checked_cast<const DecimalType&>(ty).scale());
363       if (ty.id() == Type::DECIMAL256) {
364         casted_type_id = Type::DECIMAL256;
365       }
366     } else {
367       // Non-numeric, can't cast
368       return Status::OK();
369     }
370   }
371   if (any_floating) {
372     ReplaceTypes(float64(), begin, count);
373     return Status::OK();
374   }
375 
376   // All integer and decimal, rescale
377   int32_t common_precision = 0;
378   for (auto* it = begin; it != end; ++it) {
379     const auto& ty = *it->type;
380     if (is_integer(ty.id())) {
381       ARROW_ASSIGN_OR_RAISE(auto precision, MaxDecimalDigitsForInteger(ty.id()));
382       precision += max_scale;
383       common_precision = std::max(common_precision, precision);
384     } else if (is_decimal(ty.id())) {
385       const auto& decimal_ty = checked_cast<const DecimalType&>(ty);
386       auto precision = decimal_ty.precision();
387       const auto scale = decimal_ty.scale();
388       precision += max_scale - scale;
389       common_precision = std::max(common_precision, precision);
390     }
391   }
392 
393   if (common_precision > BasicDecimal256::kMaxPrecision) {
394     return Status::Invalid("Result precision (", common_precision,
395                            ") exceeds max precision of Decimal256 (",
396                            BasicDecimal256::kMaxPrecision, ")");
397   } else if (common_precision > BasicDecimal128::kMaxPrecision) {
398     casted_type_id = Type::DECIMAL256;
399   }
400 
401   for (auto* it = begin; it != end; ++it) {
402     ARROW_ASSIGN_OR_RAISE(it->type,
403                           DecimalType::Make(casted_type_id, common_precision, max_scale));
404   }
405 
406   return Status::OK();
407 }
408 
HasDecimal(const std::vector<ValueDescr> & descrs)409 bool HasDecimal(const std::vector<ValueDescr>& descrs) {
410   for (const auto& descr : descrs) {
411     if (is_decimal(descr.type->id())) {
412       return true;
413     }
414   }
415   return false;
416 }
417 
418 }  // namespace internal
419 }  // namespace compute
420 }  // namespace arrow
421