1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 // Implementation of casting to integer, floating point, or decimal types
19 
20 #include "arrow/array/builder_primitive.h"
21 #include "arrow/compute/kernels/common.h"
22 #include "arrow/compute/kernels/scalar_cast_internal.h"
23 #include "arrow/compute/kernels/util_internal.h"
24 #include "arrow/util/bit_block_counter.h"
25 #include "arrow/util/int_util.h"
26 #include "arrow/util/value_parsing.h"
27 
28 namespace arrow {
29 
30 using internal::BitBlockCount;
31 using internal::CheckIntegersInRange;
32 using internal::IntegersCanFit;
33 using internal::OptionalBitBlockCounter;
34 using internal::ParseValue;
35 
36 namespace compute {
37 namespace internal {
38 
CastIntegerToInteger(KernelContext * ctx,const ExecBatch & batch,Datum * out)39 Status CastIntegerToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
40   const auto& options = checked_cast<const CastState*>(ctx->state())->options;
41   if (!options.allow_int_overflow) {
42     RETURN_NOT_OK(IntegersCanFit(batch[0], *out->type()));
43   }
44   CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
45   return Status::OK();
46 }
47 
CastFloatingToFloating(KernelContext *,const ExecBatch & batch,Datum * out)48 Status CastFloatingToFloating(KernelContext*, const ExecBatch& batch, Datum* out) {
49   CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
50   return Status::OK();
51 }
52 
53 // ----------------------------------------------------------------------
54 // Implement fast safe floating point to integer cast
55 
56 // InType is a floating point type we are planning to cast to integer
57 template <typename InType, typename OutType, typename InT = typename InType::c_type,
58           typename OutT = typename OutType::c_type>
59 ARROW_DISABLE_UBSAN("float-cast-overflow")
CheckFloatTruncation(const Datum & input,const Datum & output)60 Status CheckFloatTruncation(const Datum& input, const Datum& output) {
61   auto WasTruncated = [&](OutT out_val, InT in_val) -> bool {
62     return static_cast<InT>(out_val) != in_val;
63   };
64   auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool {
65     return is_valid && static_cast<InT>(out_val) != in_val;
66   };
67   auto GetErrorMessage = [&](InT val) {
68     return Status::Invalid("Float value ", val, " was truncated converting to ",
69                            *output.type());
70   };
71 
72   if (input.kind() == Datum::SCALAR) {
73     DCHECK_EQ(output.kind(), Datum::SCALAR);
74     const auto& in_scalar = input.scalar_as<typename TypeTraits<InType>::ScalarType>();
75     const auto& out_scalar = output.scalar_as<typename TypeTraits<OutType>::ScalarType>();
76     if (WasTruncatedMaybeNull(out_scalar.value, in_scalar.value, out_scalar.is_valid)) {
77       return GetErrorMessage(in_scalar.value);
78     }
79     return Status::OK();
80   }
81 
82   const ArrayData& in_array = *input.array();
83   const ArrayData& out_array = *output.array();
84 
85   const InT* in_data = in_array.GetValues<InT>(1);
86   const OutT* out_data = out_array.GetValues<OutT>(1);
87 
88   const uint8_t* bitmap = nullptr;
89   if (in_array.buffers[0]) {
90     bitmap = in_array.buffers[0]->data();
91   }
92   OptionalBitBlockCounter bit_counter(bitmap, in_array.offset, in_array.length);
93   int64_t position = 0;
94   int64_t offset_position = in_array.offset;
95   while (position < in_array.length) {
96     BitBlockCount block = bit_counter.NextBlock();
97     bool block_out_of_bounds = false;
98     if (block.popcount == block.length) {
99       // Fast path: branchless
100       for (int64_t i = 0; i < block.length; ++i) {
101         block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]);
102       }
103     } else if (block.popcount > 0) {
104       // Indices have nulls, must only boundscheck non-null values
105       for (int64_t i = 0; i < block.length; ++i) {
106         block_out_of_bounds |= WasTruncatedMaybeNull(
107             out_data[i], in_data[i], BitUtil::GetBit(bitmap, offset_position + i));
108       }
109     }
110     if (ARROW_PREDICT_FALSE(block_out_of_bounds)) {
111       if (in_array.GetNullCount() > 0) {
112         for (int64_t i = 0; i < block.length; ++i) {
113           if (WasTruncatedMaybeNull(out_data[i], in_data[i],
114                                     BitUtil::GetBit(bitmap, offset_position + i))) {
115             return GetErrorMessage(in_data[i]);
116           }
117         }
118       } else {
119         for (int64_t i = 0; i < block.length; ++i) {
120           if (WasTruncated(out_data[i], in_data[i])) {
121             return GetErrorMessage(in_data[i]);
122           }
123         }
124       }
125     }
126     in_data += block.length;
127     out_data += block.length;
128     position += block.length;
129     offset_position += block.length;
130   }
131   return Status::OK();
132 }
133 
134 template <typename InType>
CheckFloatToIntTruncationImpl(const Datum & input,const Datum & output)135 Status CheckFloatToIntTruncationImpl(const Datum& input, const Datum& output) {
136   switch (output.type()->id()) {
137     case Type::INT8:
138       return CheckFloatTruncation<InType, Int8Type>(input, output);
139     case Type::INT16:
140       return CheckFloatTruncation<InType, Int16Type>(input, output);
141     case Type::INT32:
142       return CheckFloatTruncation<InType, Int32Type>(input, output);
143     case Type::INT64:
144       return CheckFloatTruncation<InType, Int64Type>(input, output);
145     case Type::UINT8:
146       return CheckFloatTruncation<InType, UInt8Type>(input, output);
147     case Type::UINT16:
148       return CheckFloatTruncation<InType, UInt16Type>(input, output);
149     case Type::UINT32:
150       return CheckFloatTruncation<InType, UInt32Type>(input, output);
151     case Type::UINT64:
152       return CheckFloatTruncation<InType, UInt64Type>(input, output);
153     default:
154       break;
155   }
156   DCHECK(false);
157   return Status::OK();
158 }
159 
CheckFloatToIntTruncation(const Datum & input,const Datum & output)160 Status CheckFloatToIntTruncation(const Datum& input, const Datum& output) {
161   switch (input.type()->id()) {
162     case Type::FLOAT:
163       return CheckFloatToIntTruncationImpl<FloatType>(input, output);
164     case Type::DOUBLE:
165       return CheckFloatToIntTruncationImpl<DoubleType>(input, output);
166     default:
167       break;
168   }
169   DCHECK(false);
170   return Status::OK();
171 }
172 
CastFloatingToInteger(KernelContext * ctx,const ExecBatch & batch,Datum * out)173 Status CastFloatingToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
174   const auto& options = checked_cast<const CastState*>(ctx->state())->options;
175   CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
176   if (!options.allow_float_truncate) {
177     RETURN_NOT_OK(CheckFloatToIntTruncation(batch[0], *out));
178   }
179   return Status::OK();
180 }
181 
182 // ----------------------------------------------------------------------
183 // Implement fast integer to floating point cast
184 
// Magnitude up to which every whole number is exactly representable in the
// floating point type T: 2 raised to the significand width in bits
// (std::numeric_limits<T>::digits), i.e. 2^24 for float and 2^53 for double.
// Only float and double are specialized.
template <typename T>
struct FloatingIntegerBound {};

template <>
struct FloatingIntegerBound<float> {
  static constexpr int64_t value = int64_t{1} << std::numeric_limits<float>::digits;
};

template <>
struct FloatingIntegerBound<double> {
  static constexpr int64_t value = int64_t{1} << std::numeric_limits<double>::digits;
};
199 
200 template <typename InType, typename OutType, typename InT = typename InType::c_type,
201           typename OutT = typename OutType::c_type,
202           bool IsSigned = is_signed_integer_type<InType>::value>
CheckIntegerFloatTruncateImpl(const Datum & input)203 Status CheckIntegerFloatTruncateImpl(const Datum& input) {
204   using InScalarType = typename TypeTraits<InType>::ScalarType;
205   const int64_t limit = FloatingIntegerBound<OutT>::value;
206   InScalarType bound_lower(IsSigned ? -limit : 0);
207   InScalarType bound_upper(limit);
208   return CheckIntegersInRange(input, bound_lower, bound_upper);
209 }
210 
CheckForIntegerToFloatingTruncation(const Datum & input,Type::type out_type)211 Status CheckForIntegerToFloatingTruncation(const Datum& input, Type::type out_type) {
212   switch (input.type()->id()) {
213     // Small integers are all exactly representable as whole numbers
214     case Type::INT8:
215     case Type::INT16:
216     case Type::UINT8:
217     case Type::UINT16:
218       return Status::OK();
219     case Type::INT32: {
220       if (out_type == Type::DOUBLE) {
221         return Status::OK();
222       }
223       return CheckIntegerFloatTruncateImpl<Int32Type, FloatType>(input);
224     }
225     case Type::UINT32: {
226       if (out_type == Type::DOUBLE) {
227         return Status::OK();
228       }
229       return CheckIntegerFloatTruncateImpl<UInt32Type, FloatType>(input);
230     }
231     case Type::INT64: {
232       if (out_type == Type::FLOAT) {
233         return CheckIntegerFloatTruncateImpl<Int64Type, FloatType>(input);
234       } else {
235         return CheckIntegerFloatTruncateImpl<Int64Type, DoubleType>(input);
236       }
237     }
238     case Type::UINT64: {
239       if (out_type == Type::FLOAT) {
240         return CheckIntegerFloatTruncateImpl<UInt64Type, FloatType>(input);
241       } else {
242         return CheckIntegerFloatTruncateImpl<UInt64Type, DoubleType>(input);
243       }
244     }
245     default:
246       break;
247   }
248   DCHECK(false);
249   return Status::OK();
250 }
251 
CastIntegerToFloating(KernelContext * ctx,const ExecBatch & batch,Datum * out)252 Status CastIntegerToFloating(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
253   const auto& options = checked_cast<const CastState*>(ctx->state())->options;
254   Type::type out_type = out->type()->id();
255   if (!options.allow_float_truncate) {
256     RETURN_NOT_OK(CheckForIntegerToFloatingTruncation(batch[0], out_type));
257   }
258   CastNumberToNumberUnsafe(batch[0].type()->id(), out_type, batch[0], out);
259   return Status::OK();
260 }
261 
262 // ----------------------------------------------------------------------
263 // Boolean to number
264 
265 struct BooleanToNumber {
266   template <typename OutValue, typename Arg0Value>
Callarrow::compute::internal::BooleanToNumber267   static OutValue Call(KernelContext*, Arg0Value val, Status*) {
268     constexpr auto kOne = static_cast<OutValue>(1);
269     constexpr auto kZero = static_cast<OutValue>(0);
270     return val ? kOne : kZero;
271   }
272 };
273 
274 template <typename O>
275 struct CastFunctor<O, BooleanType, enable_if_number<O>> {
Execarrow::compute::internal::CastFunctor276   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
277     return applicator::ScalarUnary<O, BooleanType, BooleanToNumber>::Exec(ctx, batch,
278                                                                           out);
279   }
280 };
281 
282 // ----------------------------------------------------------------------
283 // String to number
284 
285 template <typename OutType>
286 struct ParseString {
287   template <typename OutValue, typename Arg0Value>
288   OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
289     OutValue result = OutValue(0);
290     if (ARROW_PREDICT_FALSE(!ParseValue<OutType>(val.data(), val.size(), &result))) {
291       *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar of type ",
292                             TypeTraits<OutType>::type_singleton()->ToString());
293     }
294     return result;
295   }
296 };
297 
298 template <typename O, typename I>
299 struct CastFunctor<O, I, enable_if_base_binary<I>> {
Execarrow::compute::internal::CastFunctor300   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
301     return applicator::ScalarUnaryNotNull<O, I, ParseString<O>>::Exec(ctx, batch, out);
302   }
303 };
304 
305 // ----------------------------------------------------------------------
306 // Decimal to integer
307 
308 struct DecimalToIntegerMixin {
309   template <typename OutValue, typename Arg0Value>
310   OutValue ToInteger(KernelContext* ctx, const Arg0Value& val, Status* st) const {
311     constexpr auto min_value = std::numeric_limits<OutValue>::min();
312     constexpr auto max_value = std::numeric_limits<OutValue>::max();
313 
314     if (!allow_int_overflow_ && ARROW_PREDICT_FALSE(val < min_value || val > max_value)) {
315       *st = Status::Invalid("Integer value out of bounds");
316       return OutValue{};  // Zero
317     } else {
318       return static_cast<OutValue>(val.low_bits());
319     }
320   }
321 
DecimalToIntegerMixinarrow::compute::internal::DecimalToIntegerMixin322   DecimalToIntegerMixin(int32_t in_scale, bool allow_int_overflow)
323       : in_scale_(in_scale), allow_int_overflow_(allow_int_overflow) {}
324 
325   int32_t in_scale_;
326   bool allow_int_overflow_;
327 };
328 
329 struct UnsafeUpscaleDecimalToInteger : public DecimalToIntegerMixin {
330   using DecimalToIntegerMixin::DecimalToIntegerMixin;
331 
332   template <typename OutValue, typename Arg0Value>
333   OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
334     return ToInteger<OutValue>(ctx, val.IncreaseScaleBy(-in_scale_), st);
335   }
336 };
337 
338 struct UnsafeDownscaleDecimalToInteger : public DecimalToIntegerMixin {
339   using DecimalToIntegerMixin::DecimalToIntegerMixin;
340 
341   template <typename OutValue, typename Arg0Value>
342   OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
343     return ToInteger<OutValue>(ctx, val.ReduceScaleBy(in_scale_, false), st);
344   }
345 };
346 
347 struct SafeRescaleDecimalToInteger : public DecimalToIntegerMixin {
348   using DecimalToIntegerMixin::DecimalToIntegerMixin;
349 
350   template <typename OutValue, typename Arg0Value>
351   OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
352     auto result = val.Rescale(in_scale_, 0);
353     if (ARROW_PREDICT_FALSE(!result.ok())) {
354       *st = result.status();
355       return OutValue{};  // Zero
356     } else {
357       return ToInteger<OutValue>(ctx, *result, st);
358     }
359   }
360 };
361 
362 template <typename O, typename I>
363 struct CastFunctor<O, I,
364                    enable_if_t<is_integer_type<O>::value && is_decimal_type<I>::value>> {
365   using out_type = typename O::c_type;
366 
Execarrow::compute::internal::CastFunctor367   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
368     const auto& options = checked_cast<const CastState*>(ctx->state())->options;
369 
370     const auto& in_type_inst = checked_cast<const I&>(*batch[0].type());
371     const auto in_scale = in_type_inst.scale();
372 
373     if (options.allow_decimal_truncate) {
374       if (in_scale < 0) {
375         // Unsafe upscale
376         applicator::ScalarUnaryNotNullStateful<O, I, UnsafeUpscaleDecimalToInteger>
377             kernel(UnsafeUpscaleDecimalToInteger{in_scale, options.allow_int_overflow});
378         return kernel.Exec(ctx, batch, out);
379       } else {
380         // Unsafe downscale
381         applicator::ScalarUnaryNotNullStateful<O, I, UnsafeDownscaleDecimalToInteger>
382             kernel(UnsafeDownscaleDecimalToInteger{in_scale, options.allow_int_overflow});
383         return kernel.Exec(ctx, batch, out);
384       }
385     } else {
386       // Safe rescale
387       applicator::ScalarUnaryNotNullStateful<O, I, SafeRescaleDecimalToInteger> kernel(
388           SafeRescaleDecimalToInteger{in_scale, options.allow_int_overflow});
389       return kernel.Exec(ctx, batch, out);
390     }
391   }
392 };
393 
394 // ----------------------------------------------------------------------
395 // Integer to decimal
396 
397 struct IntegerToDecimal {
398   template <typename OutValue, typename IntegerType>
399   OutValue Call(KernelContext*, IntegerType val, Status* st) const {
400     auto maybe_decimal = OutValue(val).Rescale(0, out_scale_);
401     if (ARROW_PREDICT_TRUE(maybe_decimal.ok())) {
402       return maybe_decimal.MoveValueUnsafe();
403     }
404     *st = maybe_decimal.status();
405     return OutValue{};
406   }
407 
408   int32_t out_scale_;
409 };
410 
411 template <typename O, typename I>
412 struct CastFunctor<O, I,
413                    enable_if_t<is_decimal_type<O>::value && is_integer_type<I>::value>> {
Execarrow::compute::internal::CastFunctor414   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
415     const auto& out_type = checked_cast<const O&>(*out->type());
416     const auto out_scale = out_type.scale();
417     const auto out_precision = out_type.precision();
418 
419     // verify precision and scale
420     if (out_scale < 0) {
421       return Status::Invalid("Scale must be non-negative");
422     }
423     ARROW_ASSIGN_OR_RAISE(int32_t precision, MaxDecimalDigitsForInteger(I::type_id));
424     precision += out_scale;
425     if (out_precision < precision) {
426       return Status::Invalid(
427           "Precision is not great enough for the result. "
428           "It should be at least ",
429           precision);
430     }
431 
432     applicator::ScalarUnaryNotNullStateful<O, I, IntegerToDecimal> kernel(
433         IntegerToDecimal{out_scale});
434     return kernel.Exec(ctx, batch, out);
435   }
436 };
437 
438 // ----------------------------------------------------------------------
439 // Decimal to decimal
440 
441 // Helper that converts the input and output decimals
442 // For instance, Decimal128 -> Decimal256 requires converting, then scaling
443 // Decimal256 -> Decimal128 requires scaling, then truncating
444 template <typename OutDecimal, typename InDecimal>
445 struct DecimalConversions {};
446 
447 template <typename InDecimal>
448 struct DecimalConversions<Decimal256, InDecimal> {
449   // Convert then scale
ConvertInputarrow::compute::internal::DecimalConversions450   static Decimal256 ConvertInput(InDecimal&& val) { return Decimal256(val); }
ConvertOutputarrow::compute::internal::DecimalConversions451   static Decimal256 ConvertOutput(Decimal256&& val) { return val; }
452 };
453 
454 template <>
455 struct DecimalConversions<Decimal128, Decimal256> {
456   // Scale then truncate
ConvertInputarrow::compute::internal::DecimalConversions457   static Decimal256 ConvertInput(Decimal256&& val) { return val; }
ConvertOutputarrow::compute::internal::DecimalConversions458   static Decimal128 ConvertOutput(Decimal256&& val) {
459     const auto array_le = BitUtil::LittleEndianArray::Make(val.native_endian_array());
460     return Decimal128(array_le[1], array_le[0]);
461   }
462 };
463 
464 template <>
465 struct DecimalConversions<Decimal128, Decimal128> {
ConvertInputarrow::compute::internal::DecimalConversions466   static Decimal128 ConvertInput(Decimal128&& val) { return val; }
ConvertOutputarrow::compute::internal::DecimalConversions467   static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
468 };
469 
470 struct UnsafeUpscaleDecimal {
471   template <typename OutValue, typename Arg0Value>
472   OutValue Call(KernelContext*, Arg0Value val, Status*) const {
473     using Conv = DecimalConversions<OutValue, Arg0Value>;
474     return Conv::ConvertOutput(Conv::ConvertInput(std::move(val)).IncreaseScaleBy(by_));
475   }
476   int32_t by_;
477 };
478 
479 struct UnsafeDownscaleDecimal {
480   template <typename OutValue, typename Arg0Value>
481   OutValue Call(KernelContext*, Arg0Value val, Status*) const {
482     using Conv = DecimalConversions<OutValue, Arg0Value>;
483     return Conv::ConvertOutput(
484         Conv::ConvertInput(std::move(val)).ReduceScaleBy(by_, false));
485   }
486   int32_t by_;
487 };
488 
489 struct SafeRescaleDecimal {
490   template <typename OutValue, typename Arg0Value>
491   OutValue Call(KernelContext*, Arg0Value val, Status* st) const {
492     using Conv = DecimalConversions<OutValue, Arg0Value>;
493     auto maybe_rescaled =
494         Conv::ConvertInput(std::move(val)).Rescale(in_scale_, out_scale_);
495     if (ARROW_PREDICT_FALSE(!maybe_rescaled.ok())) {
496       *st = maybe_rescaled.status();
497       return {};  // Zero
498     }
499 
500     if (ARROW_PREDICT_TRUE(maybe_rescaled->FitsInPrecision(out_precision_))) {
501       return Conv::ConvertOutput(maybe_rescaled.MoveValueUnsafe());
502     }
503 
504     *st = Status::Invalid("Decimal value does not fit in precision ", out_precision_);
505     return {};  // Zero
506   }
507 
508   int32_t out_scale_, out_precision_, in_scale_;
509 };
510 
511 template <typename O, typename I>
512 struct CastFunctor<O, I,
513                    enable_if_t<is_decimal_type<O>::value && is_decimal_type<I>::value>> {
Execarrow::compute::internal::CastFunctor514   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
515     const auto& options = checked_cast<const CastState*>(ctx->state())->options;
516 
517     const auto& in_type = checked_cast<const I&>(*batch[0].type());
518     const auto& out_type = checked_cast<const O&>(*out->type());
519     const auto in_scale = in_type.scale();
520     const auto out_scale = out_type.scale();
521 
522     if (options.allow_decimal_truncate) {
523       if (in_scale < out_scale) {
524         // Unsafe upscale
525         applicator::ScalarUnaryNotNullStateful<O, I, UnsafeUpscaleDecimal> kernel(
526             UnsafeUpscaleDecimal{out_scale - in_scale});
527         return kernel.Exec(ctx, batch, out);
528       } else {
529         // Unsafe downscale
530         applicator::ScalarUnaryNotNullStateful<O, I, UnsafeDownscaleDecimal> kernel(
531             UnsafeDownscaleDecimal{in_scale - out_scale});
532         return kernel.Exec(ctx, batch, out);
533       }
534     }
535 
536     // Safe rescale
537     applicator::ScalarUnaryNotNullStateful<O, I, SafeRescaleDecimal> kernel(
538         SafeRescaleDecimal{out_scale, out_type.precision(), in_scale});
539     return kernel.Exec(ctx, batch, out);
540   }
541 };
542 
543 // ----------------------------------------------------------------------
544 // Real to decimal
545 
546 struct RealToDecimal {
547   template <typename OutValue, typename RealType>
548   OutValue Call(KernelContext*, RealType val, Status* st) const {
549     auto maybe_decimal = OutValue::FromReal(val, out_precision_, out_scale_);
550 
551     if (ARROW_PREDICT_TRUE(maybe_decimal.ok())) {
552       return maybe_decimal.MoveValueUnsafe();
553     }
554 
555     if (!allow_truncate_) {
556       *st = maybe_decimal.status();
557     }
558     return {};  // Zero
559   }
560 
561   int32_t out_scale_, out_precision_;
562   bool allow_truncate_;
563 };
564 
565 template <typename O, typename I>
566 struct CastFunctor<O, I,
567                    enable_if_t<is_decimal_type<O>::value && is_floating_type<I>::value>> {
Execarrow::compute::internal::CastFunctor568   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
569     const auto& options = checked_cast<const CastState*>(ctx->state())->options;
570     const auto& out_type = checked_cast<const O&>(*out->type());
571     const auto out_scale = out_type.scale();
572     const auto out_precision = out_type.precision();
573 
574     applicator::ScalarUnaryNotNullStateful<O, I, RealToDecimal> kernel(
575         RealToDecimal{out_scale, out_precision, options.allow_decimal_truncate});
576     return kernel.Exec(ctx, batch, out);
577   }
578 };
579 
580 // ----------------------------------------------------------------------
581 // Decimal to real
582 
583 struct DecimalToReal {
584   template <typename RealType, typename Arg0Value>
585   RealType Call(KernelContext*, const Arg0Value& val, Status*) const {
586     return val.template ToReal<RealType>(in_scale_);
587   }
588 
589   int32_t in_scale_;
590 };
591 
592 template <typename O, typename I>
593 struct CastFunctor<O, I,
594                    enable_if_t<is_floating_type<O>::value && is_decimal_type<I>::value>> {
Execarrow::compute::internal::CastFunctor595   static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
596     const auto& in_type = checked_cast<const I&>(*batch[0].type());
597     const auto in_scale = in_type.scale();
598 
599     applicator::ScalarUnaryNotNullStateful<O, I, DecimalToReal> kernel(
600         DecimalToReal{in_scale});
601     return kernel.Exec(ctx, batch, out);
602   }
603 };
604 
605 // ----------------------------------------------------------------------
606 // Top-level kernel instantiation
607 
608 namespace {
609 
610 template <typename OutType>
AddCommonNumberCasts(const std::shared_ptr<DataType> & out_ty,CastFunction * func)611 void AddCommonNumberCasts(const std::shared_ptr<DataType>& out_ty, CastFunction* func) {
612   AddCommonCasts(out_ty->id(), out_ty, func);
613 
614   // Cast from boolean to number
615   DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty,
616                             CastFunctor<OutType, BooleanType>::Exec));
617 
618   // Cast from other strings
619   for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
620     auto exec = GenerateVarBinaryBase<CastFunctor, OutType>(*in_ty);
621     DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, exec));
622   }
623 }
624 
625 template <typename OutType>
GetCastToInteger(std::string name)626 std::shared_ptr<CastFunction> GetCastToInteger(std::string name) {
627   auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id);
628   auto out_ty = TypeTraits<OutType>::type_singleton();
629 
630   for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
631     DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToInteger));
632   }
633 
634   // Cast from floating point
635   for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) {
636     DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger));
637   }
638 
639   // From other numbers to integer
640   AddCommonNumberCasts<OutType>(out_ty, func.get());
641 
642   // From decimal to integer
643   DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
644                             CastFunctor<OutType, Decimal128Type>::Exec));
645   DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
646                             CastFunctor<OutType, Decimal256Type>::Exec));
647   return func;
648 }
649 
650 template <typename OutType>
GetCastToFloating(std::string name)651 std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
652   auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id);
653   auto out_ty = TypeTraits<OutType>::type_singleton();
654 
655   // Casts from integer to floating point
656   for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
657     DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToFloating));
658   }
659 
660   // Cast from floating point
661   for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) {
662     DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating));
663   }
664 
665   // From other numbers to floating point
666   AddCommonNumberCasts<OutType>(out_ty, func.get());
667 
668   // From decimal to floating point
669   DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
670                             CastFunctor<OutType, Decimal128Type>::Exec));
671   DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
672                             CastFunctor<OutType, Decimal256Type>::Exec));
673   return func;
674 }
675 
GetCastToDecimal128()676 std::shared_ptr<CastFunction> GetCastToDecimal128() {
677   OutputType sig_out_ty(ResolveOutputFromOptions);
678 
679   auto func = std::make_shared<CastFunction>("cast_decimal", Type::DECIMAL128);
680   AddCommonCasts(Type::DECIMAL128, sig_out_ty, func.get());
681 
682   // Cast from floating point
683   DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
684                             CastFunctor<Decimal128Type, FloatType>::Exec));
685   DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
686                             CastFunctor<Decimal128Type, DoubleType>::Exec));
687 
688   // Cast from integer
689   for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
690     auto exec = GenerateInteger<CastFunctor, Decimal128Type>(in_ty->id());
691     DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
692   }
693 
694   // Cast from other decimal
695   auto exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec;
696   // We resolve the output type of this kernel from the CastOptions
697   DCHECK_OK(
698       func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
699   exec = CastFunctor<Decimal128Type, Decimal256Type>::Exec;
700   DCHECK_OK(
701       func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
702   return func;
703 }
704 
GetCastToDecimal256()705 std::shared_ptr<CastFunction> GetCastToDecimal256() {
706   OutputType sig_out_ty(ResolveOutputFromOptions);
707 
708   auto func = std::make_shared<CastFunction>("cast_decimal256", Type::DECIMAL256);
709   AddCommonCasts(Type::DECIMAL256, sig_out_ty, func.get());
710 
711   // Cast from floating point
712   DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
713                             CastFunctor<Decimal256Type, FloatType>::Exec));
714   DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
715                             CastFunctor<Decimal256Type, DoubleType>::Exec));
716 
717   // Cast from integer
718   for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
719     auto exec = GenerateInteger<CastFunctor, Decimal256Type>(in_ty->id());
720     DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
721   }
722 
723   // Cast from other decimal
724   auto exec = CastFunctor<Decimal256Type, Decimal128Type>::Exec;
725   DCHECK_OK(
726       func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
727   exec = CastFunctor<Decimal256Type, Decimal256Type>::Exec;
728   DCHECK_OK(
729       func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
730   return func;
731 }
732 
733 }  // namespace
734 
GetNumericCasts()735 std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
736   std::vector<std::shared_ptr<CastFunction>> functions;
737 
738   // Make a cast to null that does not do much. Not sure why we need to be able
739   // to cast from dict<null> -> null but there are unit tests for it
740   auto cast_null = std::make_shared<CastFunction>("cast_null", Type::NA);
741   DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(),
742                                  OutputAllNull));
743   functions.push_back(cast_null);
744 
745   functions.push_back(GetCastToInteger<Int8Type>("cast_int8"));
746   functions.push_back(GetCastToInteger<Int16Type>("cast_int16"));
747 
748   auto cast_int32 = GetCastToInteger<Int32Type>("cast_int32");
749   // Convert DATE32 or TIME32 to INT32 zero copy
750   AddZeroCopyCast(Type::DATE32, date32(), int32(), cast_int32.get());
751   AddZeroCopyCast(Type::TIME32, InputType(Type::TIME32), int32(), cast_int32.get());
752   functions.push_back(cast_int32);
753 
754   auto cast_int64 = GetCastToInteger<Int64Type>("cast_int64");
755   // Convert DATE64, DURATION, TIMESTAMP, TIME64 to INT64 zero copy
756   AddZeroCopyCast(Type::DATE64, InputType(Type::DATE64), int64(), cast_int64.get());
757   AddZeroCopyCast(Type::DURATION, InputType(Type::DURATION), int64(), cast_int64.get());
758   AddZeroCopyCast(Type::TIMESTAMP, InputType(Type::TIMESTAMP), int64(), cast_int64.get());
759   AddZeroCopyCast(Type::TIME64, InputType(Type::TIME64), int64(), cast_int64.get());
760   functions.push_back(cast_int64);
761 
762   functions.push_back(GetCastToInteger<UInt8Type>("cast_uint8"));
763   functions.push_back(GetCastToInteger<UInt16Type>("cast_uint16"));
764   functions.push_back(GetCastToInteger<UInt32Type>("cast_uint32"));
765   functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64"));
766 
767   // HalfFloat is a bit brain-damaged for now
768   auto cast_half_float =
769       std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT);
770   AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get());
771   functions.push_back(cast_half_float);
772 
773   functions.push_back(GetCastToFloating<FloatType>("cast_float"));
774   functions.push_back(GetCastToFloating<DoubleType>("cast_double"));
775 
776   functions.push_back(GetCastToDecimal128());
777   functions.push_back(GetCastToDecimal256());
778 
779   return functions;
780 }
781 
782 }  // namespace internal
783 }  // namespace compute
784 }  // namespace arrow
785