1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 // Implementation of casting to integer, floating point, or decimal types
19
20 #include "arrow/array/builder_primitive.h"
21 #include "arrow/compute/kernels/common.h"
22 #include "arrow/compute/kernels/scalar_cast_internal.h"
23 #include "arrow/compute/kernels/util_internal.h"
24 #include "arrow/util/bit_block_counter.h"
25 #include "arrow/util/int_util.h"
26 #include "arrow/util/value_parsing.h"
27
28 namespace arrow {
29
30 using internal::BitBlockCount;
31 using internal::CheckIntegersInRange;
32 using internal::IntegersCanFit;
33 using internal::OptionalBitBlockCounter;
34 using internal::ParseValue;
35
36 namespace compute {
37 namespace internal {
38
CastIntegerToInteger(KernelContext * ctx,const ExecBatch & batch,Datum * out)39 Status CastIntegerToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
40 const auto& options = checked_cast<const CastState*>(ctx->state())->options;
41 if (!options.allow_int_overflow) {
42 RETURN_NOT_OK(IntegersCanFit(batch[0], *out->type()));
43 }
44 CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
45 return Status::OK();
46 }
47
CastFloatingToFloating(KernelContext *,const ExecBatch & batch,Datum * out)48 Status CastFloatingToFloating(KernelContext*, const ExecBatch& batch, Datum* out) {
49 CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
50 return Status::OK();
51 }
52
53 // ----------------------------------------------------------------------
54 // Implement fast safe floating point to integer cast
55
56 // InType is a floating point type we are planning to cast to integer
57 template <typename InType, typename OutType, typename InT = typename InType::c_type,
58 typename OutT = typename OutType::c_type>
59 ARROW_DISABLE_UBSAN("float-cast-overflow")
CheckFloatTruncation(const Datum & input,const Datum & output)60 Status CheckFloatTruncation(const Datum& input, const Datum& output) {
61 auto WasTruncated = [&](OutT out_val, InT in_val) -> bool {
62 return static_cast<InT>(out_val) != in_val;
63 };
64 auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool {
65 return is_valid && static_cast<InT>(out_val) != in_val;
66 };
67 auto GetErrorMessage = [&](InT val) {
68 return Status::Invalid("Float value ", val, " was truncated converting to ",
69 *output.type());
70 };
71
72 if (input.kind() == Datum::SCALAR) {
73 DCHECK_EQ(output.kind(), Datum::SCALAR);
74 const auto& in_scalar = input.scalar_as<typename TypeTraits<InType>::ScalarType>();
75 const auto& out_scalar = output.scalar_as<typename TypeTraits<OutType>::ScalarType>();
76 if (WasTruncatedMaybeNull(out_scalar.value, in_scalar.value, out_scalar.is_valid)) {
77 return GetErrorMessage(in_scalar.value);
78 }
79 return Status::OK();
80 }
81
82 const ArrayData& in_array = *input.array();
83 const ArrayData& out_array = *output.array();
84
85 const InT* in_data = in_array.GetValues<InT>(1);
86 const OutT* out_data = out_array.GetValues<OutT>(1);
87
88 const uint8_t* bitmap = nullptr;
89 if (in_array.buffers[0]) {
90 bitmap = in_array.buffers[0]->data();
91 }
92 OptionalBitBlockCounter bit_counter(bitmap, in_array.offset, in_array.length);
93 int64_t position = 0;
94 int64_t offset_position = in_array.offset;
95 while (position < in_array.length) {
96 BitBlockCount block = bit_counter.NextBlock();
97 bool block_out_of_bounds = false;
98 if (block.popcount == block.length) {
99 // Fast path: branchless
100 for (int64_t i = 0; i < block.length; ++i) {
101 block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]);
102 }
103 } else if (block.popcount > 0) {
104 // Indices have nulls, must only boundscheck non-null values
105 for (int64_t i = 0; i < block.length; ++i) {
106 block_out_of_bounds |= WasTruncatedMaybeNull(
107 out_data[i], in_data[i], BitUtil::GetBit(bitmap, offset_position + i));
108 }
109 }
110 if (ARROW_PREDICT_FALSE(block_out_of_bounds)) {
111 if (in_array.GetNullCount() > 0) {
112 for (int64_t i = 0; i < block.length; ++i) {
113 if (WasTruncatedMaybeNull(out_data[i], in_data[i],
114 BitUtil::GetBit(bitmap, offset_position + i))) {
115 return GetErrorMessage(in_data[i]);
116 }
117 }
118 } else {
119 for (int64_t i = 0; i < block.length; ++i) {
120 if (WasTruncated(out_data[i], in_data[i])) {
121 return GetErrorMessage(in_data[i]);
122 }
123 }
124 }
125 }
126 in_data += block.length;
127 out_data += block.length;
128 position += block.length;
129 offset_position += block.length;
130 }
131 return Status::OK();
132 }
133
134 template <typename InType>
CheckFloatToIntTruncationImpl(const Datum & input,const Datum & output)135 Status CheckFloatToIntTruncationImpl(const Datum& input, const Datum& output) {
136 switch (output.type()->id()) {
137 case Type::INT8:
138 return CheckFloatTruncation<InType, Int8Type>(input, output);
139 case Type::INT16:
140 return CheckFloatTruncation<InType, Int16Type>(input, output);
141 case Type::INT32:
142 return CheckFloatTruncation<InType, Int32Type>(input, output);
143 case Type::INT64:
144 return CheckFloatTruncation<InType, Int64Type>(input, output);
145 case Type::UINT8:
146 return CheckFloatTruncation<InType, UInt8Type>(input, output);
147 case Type::UINT16:
148 return CheckFloatTruncation<InType, UInt16Type>(input, output);
149 case Type::UINT32:
150 return CheckFloatTruncation<InType, UInt32Type>(input, output);
151 case Type::UINT64:
152 return CheckFloatTruncation<InType, UInt64Type>(input, output);
153 default:
154 break;
155 }
156 DCHECK(false);
157 return Status::OK();
158 }
159
CheckFloatToIntTruncation(const Datum & input,const Datum & output)160 Status CheckFloatToIntTruncation(const Datum& input, const Datum& output) {
161 switch (input.type()->id()) {
162 case Type::FLOAT:
163 return CheckFloatToIntTruncationImpl<FloatType>(input, output);
164 case Type::DOUBLE:
165 return CheckFloatToIntTruncationImpl<DoubleType>(input, output);
166 default:
167 break;
168 }
169 DCHECK(false);
170 return Status::OK();
171 }
172
CastFloatingToInteger(KernelContext * ctx,const ExecBatch & batch,Datum * out)173 Status CastFloatingToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
174 const auto& options = checked_cast<const CastState*>(ctx->state())->options;
175 CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
176 if (!options.allow_float_truncate) {
177 RETURN_NOT_OK(CheckFloatToIntTruncation(batch[0], *out));
178 }
179 return Status::OK();
180 }
181
182 // ----------------------------------------------------------------------
183 // Implement fast integer to floating point cast
184
185 // These are the limits for exact representation of whole numbers in floating
186 // point numbers
// Primary template: intentionally empty, only the specializations below carry
// a `value`.
template <typename T>
struct FloatingIntegerBound {};

// Bound used for float: 2^24
template <>
struct FloatingIntegerBound<float> {
  static constexpr int64_t value = int64_t{1} << 24;
};

// Bound used for double: 2^53
template <>
struct FloatingIntegerBound<double> {
  static constexpr int64_t value = int64_t{1} << 53;
};
199
200 template <typename InType, typename OutType, typename InT = typename InType::c_type,
201 typename OutT = typename OutType::c_type,
202 bool IsSigned = is_signed_integer_type<InType>::value>
CheckIntegerFloatTruncateImpl(const Datum & input)203 Status CheckIntegerFloatTruncateImpl(const Datum& input) {
204 using InScalarType = typename TypeTraits<InType>::ScalarType;
205 const int64_t limit = FloatingIntegerBound<OutT>::value;
206 InScalarType bound_lower(IsSigned ? -limit : 0);
207 InScalarType bound_upper(limit);
208 return CheckIntegersInRange(input, bound_lower, bound_upper);
209 }
210
CheckForIntegerToFloatingTruncation(const Datum & input,Type::type out_type)211 Status CheckForIntegerToFloatingTruncation(const Datum& input, Type::type out_type) {
212 switch (input.type()->id()) {
213 // Small integers are all exactly representable as whole numbers
214 case Type::INT8:
215 case Type::INT16:
216 case Type::UINT8:
217 case Type::UINT16:
218 return Status::OK();
219 case Type::INT32: {
220 if (out_type == Type::DOUBLE) {
221 return Status::OK();
222 }
223 return CheckIntegerFloatTruncateImpl<Int32Type, FloatType>(input);
224 }
225 case Type::UINT32: {
226 if (out_type == Type::DOUBLE) {
227 return Status::OK();
228 }
229 return CheckIntegerFloatTruncateImpl<UInt32Type, FloatType>(input);
230 }
231 case Type::INT64: {
232 if (out_type == Type::FLOAT) {
233 return CheckIntegerFloatTruncateImpl<Int64Type, FloatType>(input);
234 } else {
235 return CheckIntegerFloatTruncateImpl<Int64Type, DoubleType>(input);
236 }
237 }
238 case Type::UINT64: {
239 if (out_type == Type::FLOAT) {
240 return CheckIntegerFloatTruncateImpl<UInt64Type, FloatType>(input);
241 } else {
242 return CheckIntegerFloatTruncateImpl<UInt64Type, DoubleType>(input);
243 }
244 }
245 default:
246 break;
247 }
248 DCHECK(false);
249 return Status::OK();
250 }
251
CastIntegerToFloating(KernelContext * ctx,const ExecBatch & batch,Datum * out)252 Status CastIntegerToFloating(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
253 const auto& options = checked_cast<const CastState*>(ctx->state())->options;
254 Type::type out_type = out->type()->id();
255 if (!options.allow_float_truncate) {
256 RETURN_NOT_OK(CheckForIntegerToFloatingTruncation(batch[0], out_type));
257 }
258 CastNumberToNumberUnsafe(batch[0].type()->id(), out_type, batch[0], out);
259 return Status::OK();
260 }
261
262 // ----------------------------------------------------------------------
263 // Boolean to number
264
265 struct BooleanToNumber {
266 template <typename OutValue, typename Arg0Value>
Callarrow::compute::internal::BooleanToNumber267 static OutValue Call(KernelContext*, Arg0Value val, Status*) {
268 constexpr auto kOne = static_cast<OutValue>(1);
269 constexpr auto kZero = static_cast<OutValue>(0);
270 return val ? kOne : kZero;
271 }
272 };
273
274 template <typename O>
275 struct CastFunctor<O, BooleanType, enable_if_number<O>> {
Execarrow::compute::internal::CastFunctor276 static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
277 return applicator::ScalarUnary<O, BooleanType, BooleanToNumber>::Exec(ctx, batch,
278 out);
279 }
280 };
281
282 // ----------------------------------------------------------------------
283 // String to number
284
285 template <typename OutType>
286 struct ParseString {
287 template <typename OutValue, typename Arg0Value>
288 OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
289 OutValue result = OutValue(0);
290 if (ARROW_PREDICT_FALSE(!ParseValue<OutType>(val.data(), val.size(), &result))) {
291 *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar of type ",
292 TypeTraits<OutType>::type_singleton()->ToString());
293 }
294 return result;
295 }
296 };
297
298 template <typename O, typename I>
299 struct CastFunctor<O, I, enable_if_base_binary<I>> {
Execarrow::compute::internal::CastFunctor300 static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
301 return applicator::ScalarUnaryNotNull<O, I, ParseString<O>>::Exec(ctx, batch, out);
302 }
303 };
304
305 // ----------------------------------------------------------------------
306 // Decimal to integer
307
308 struct DecimalToIntegerMixin {
309 template <typename OutValue, typename Arg0Value>
310 OutValue ToInteger(KernelContext* ctx, const Arg0Value& val, Status* st) const {
311 constexpr auto min_value = std::numeric_limits<OutValue>::min();
312 constexpr auto max_value = std::numeric_limits<OutValue>::max();
313
314 if (!allow_int_overflow_ && ARROW_PREDICT_FALSE(val < min_value || val > max_value)) {
315 *st = Status::Invalid("Integer value out of bounds");
316 return OutValue{}; // Zero
317 } else {
318 return static_cast<OutValue>(val.low_bits());
319 }
320 }
321
DecimalToIntegerMixinarrow::compute::internal::DecimalToIntegerMixin322 DecimalToIntegerMixin(int32_t in_scale, bool allow_int_overflow)
323 : in_scale_(in_scale), allow_int_overflow_(allow_int_overflow) {}
324
325 int32_t in_scale_;
326 bool allow_int_overflow_;
327 };
328
329 struct UnsafeUpscaleDecimalToInteger : public DecimalToIntegerMixin {
330 using DecimalToIntegerMixin::DecimalToIntegerMixin;
331
332 template <typename OutValue, typename Arg0Value>
333 OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
334 return ToInteger<OutValue>(ctx, val.IncreaseScaleBy(-in_scale_), st);
335 }
336 };
337
338 struct UnsafeDownscaleDecimalToInteger : public DecimalToIntegerMixin {
339 using DecimalToIntegerMixin::DecimalToIntegerMixin;
340
341 template <typename OutValue, typename Arg0Value>
342 OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
343 return ToInteger<OutValue>(ctx, val.ReduceScaleBy(in_scale_, false), st);
344 }
345 };
346
347 struct SafeRescaleDecimalToInteger : public DecimalToIntegerMixin {
348 using DecimalToIntegerMixin::DecimalToIntegerMixin;
349
350 template <typename OutValue, typename Arg0Value>
351 OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
352 auto result = val.Rescale(in_scale_, 0);
353 if (ARROW_PREDICT_FALSE(!result.ok())) {
354 *st = result.status();
355 return OutValue{}; // Zero
356 } else {
357 return ToInteger<OutValue>(ctx, *result, st);
358 }
359 }
360 };
361
362 template <typename O, typename I>
363 struct CastFunctor<O, I,
364 enable_if_t<is_integer_type<O>::value && is_decimal_type<I>::value>> {
365 using out_type = typename O::c_type;
366
Execarrow::compute::internal::CastFunctor367 static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
368 const auto& options = checked_cast<const CastState*>(ctx->state())->options;
369
370 const auto& in_type_inst = checked_cast<const I&>(*batch[0].type());
371 const auto in_scale = in_type_inst.scale();
372
373 if (options.allow_decimal_truncate) {
374 if (in_scale < 0) {
375 // Unsafe upscale
376 applicator::ScalarUnaryNotNullStateful<O, I, UnsafeUpscaleDecimalToInteger>
377 kernel(UnsafeUpscaleDecimalToInteger{in_scale, options.allow_int_overflow});
378 return kernel.Exec(ctx, batch, out);
379 } else {
380 // Unsafe downscale
381 applicator::ScalarUnaryNotNullStateful<O, I, UnsafeDownscaleDecimalToInteger>
382 kernel(UnsafeDownscaleDecimalToInteger{in_scale, options.allow_int_overflow});
383 return kernel.Exec(ctx, batch, out);
384 }
385 } else {
386 // Safe rescale
387 applicator::ScalarUnaryNotNullStateful<O, I, SafeRescaleDecimalToInteger> kernel(
388 SafeRescaleDecimalToInteger{in_scale, options.allow_int_overflow});
389 return kernel.Exec(ctx, batch, out);
390 }
391 }
392 };
393
394 // ----------------------------------------------------------------------
395 // Integer to decimal
396
397 struct IntegerToDecimal {
398 template <typename OutValue, typename IntegerType>
399 OutValue Call(KernelContext*, IntegerType val, Status* st) const {
400 auto maybe_decimal = OutValue(val).Rescale(0, out_scale_);
401 if (ARROW_PREDICT_TRUE(maybe_decimal.ok())) {
402 return maybe_decimal.MoveValueUnsafe();
403 }
404 *st = maybe_decimal.status();
405 return OutValue{};
406 }
407
408 int32_t out_scale_;
409 };
410
411 template <typename O, typename I>
412 struct CastFunctor<O, I,
413 enable_if_t<is_decimal_type<O>::value && is_integer_type<I>::value>> {
Execarrow::compute::internal::CastFunctor414 static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
415 const auto& out_type = checked_cast<const O&>(*out->type());
416 const auto out_scale = out_type.scale();
417 const auto out_precision = out_type.precision();
418
419 // verify precision and scale
420 if (out_scale < 0) {
421 return Status::Invalid("Scale must be non-negative");
422 }
423 ARROW_ASSIGN_OR_RAISE(int32_t precision, MaxDecimalDigitsForInteger(I::type_id));
424 precision += out_scale;
425 if (out_precision < precision) {
426 return Status::Invalid(
427 "Precision is not great enough for the result. "
428 "It should be at least ",
429 precision);
430 }
431
432 applicator::ScalarUnaryNotNullStateful<O, I, IntegerToDecimal> kernel(
433 IntegerToDecimal{out_scale});
434 return kernel.Exec(ctx, batch, out);
435 }
436 };
437
438 // ----------------------------------------------------------------------
439 // Decimal to decimal
440
441 // Helper that converts the input and output decimals
442 // For instance, Decimal128 -> Decimal256 requires converting, then scaling
443 // Decimal256 -> Decimal128 requires scaling, then truncating
444 template <typename OutDecimal, typename InDecimal>
445 struct DecimalConversions {};
446
447 template <typename InDecimal>
448 struct DecimalConversions<Decimal256, InDecimal> {
449 // Convert then scale
ConvertInputarrow::compute::internal::DecimalConversions450 static Decimal256 ConvertInput(InDecimal&& val) { return Decimal256(val); }
ConvertOutputarrow::compute::internal::DecimalConversions451 static Decimal256 ConvertOutput(Decimal256&& val) { return val; }
452 };
453
454 template <>
455 struct DecimalConversions<Decimal128, Decimal256> {
456 // Scale then truncate
ConvertInputarrow::compute::internal::DecimalConversions457 static Decimal256 ConvertInput(Decimal256&& val) { return val; }
ConvertOutputarrow::compute::internal::DecimalConversions458 static Decimal128 ConvertOutput(Decimal256&& val) {
459 const auto array_le = BitUtil::LittleEndianArray::Make(val.native_endian_array());
460 return Decimal128(array_le[1], array_le[0]);
461 }
462 };
463
464 template <>
465 struct DecimalConversions<Decimal128, Decimal128> {
ConvertInputarrow::compute::internal::DecimalConversions466 static Decimal128 ConvertInput(Decimal128&& val) { return val; }
ConvertOutputarrow::compute::internal::DecimalConversions467 static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
468 };
469
470 struct UnsafeUpscaleDecimal {
471 template <typename OutValue, typename Arg0Value>
472 OutValue Call(KernelContext*, Arg0Value val, Status*) const {
473 using Conv = DecimalConversions<OutValue, Arg0Value>;
474 return Conv::ConvertOutput(Conv::ConvertInput(std::move(val)).IncreaseScaleBy(by_));
475 }
476 int32_t by_;
477 };
478
479 struct UnsafeDownscaleDecimal {
480 template <typename OutValue, typename Arg0Value>
481 OutValue Call(KernelContext*, Arg0Value val, Status*) const {
482 using Conv = DecimalConversions<OutValue, Arg0Value>;
483 return Conv::ConvertOutput(
484 Conv::ConvertInput(std::move(val)).ReduceScaleBy(by_, false));
485 }
486 int32_t by_;
487 };
488
489 struct SafeRescaleDecimal {
490 template <typename OutValue, typename Arg0Value>
491 OutValue Call(KernelContext*, Arg0Value val, Status* st) const {
492 using Conv = DecimalConversions<OutValue, Arg0Value>;
493 auto maybe_rescaled =
494 Conv::ConvertInput(std::move(val)).Rescale(in_scale_, out_scale_);
495 if (ARROW_PREDICT_FALSE(!maybe_rescaled.ok())) {
496 *st = maybe_rescaled.status();
497 return {}; // Zero
498 }
499
500 if (ARROW_PREDICT_TRUE(maybe_rescaled->FitsInPrecision(out_precision_))) {
501 return Conv::ConvertOutput(maybe_rescaled.MoveValueUnsafe());
502 }
503
504 *st = Status::Invalid("Decimal value does not fit in precision ", out_precision_);
505 return {}; // Zero
506 }
507
508 int32_t out_scale_, out_precision_, in_scale_;
509 };
510
511 template <typename O, typename I>
512 struct CastFunctor<O, I,
513 enable_if_t<is_decimal_type<O>::value && is_decimal_type<I>::value>> {
Execarrow::compute::internal::CastFunctor514 static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
515 const auto& options = checked_cast<const CastState*>(ctx->state())->options;
516
517 const auto& in_type = checked_cast<const I&>(*batch[0].type());
518 const auto& out_type = checked_cast<const O&>(*out->type());
519 const auto in_scale = in_type.scale();
520 const auto out_scale = out_type.scale();
521
522 if (options.allow_decimal_truncate) {
523 if (in_scale < out_scale) {
524 // Unsafe upscale
525 applicator::ScalarUnaryNotNullStateful<O, I, UnsafeUpscaleDecimal> kernel(
526 UnsafeUpscaleDecimal{out_scale - in_scale});
527 return kernel.Exec(ctx, batch, out);
528 } else {
529 // Unsafe downscale
530 applicator::ScalarUnaryNotNullStateful<O, I, UnsafeDownscaleDecimal> kernel(
531 UnsafeDownscaleDecimal{in_scale - out_scale});
532 return kernel.Exec(ctx, batch, out);
533 }
534 }
535
536 // Safe rescale
537 applicator::ScalarUnaryNotNullStateful<O, I, SafeRescaleDecimal> kernel(
538 SafeRescaleDecimal{out_scale, out_type.precision(), in_scale});
539 return kernel.Exec(ctx, batch, out);
540 }
541 };
542
543 // ----------------------------------------------------------------------
544 // Real to decimal
545
546 struct RealToDecimal {
547 template <typename OutValue, typename RealType>
548 OutValue Call(KernelContext*, RealType val, Status* st) const {
549 auto maybe_decimal = OutValue::FromReal(val, out_precision_, out_scale_);
550
551 if (ARROW_PREDICT_TRUE(maybe_decimal.ok())) {
552 return maybe_decimal.MoveValueUnsafe();
553 }
554
555 if (!allow_truncate_) {
556 *st = maybe_decimal.status();
557 }
558 return {}; // Zero
559 }
560
561 int32_t out_scale_, out_precision_;
562 bool allow_truncate_;
563 };
564
565 template <typename O, typename I>
566 struct CastFunctor<O, I,
567 enable_if_t<is_decimal_type<O>::value && is_floating_type<I>::value>> {
Execarrow::compute::internal::CastFunctor568 static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
569 const auto& options = checked_cast<const CastState*>(ctx->state())->options;
570 const auto& out_type = checked_cast<const O&>(*out->type());
571 const auto out_scale = out_type.scale();
572 const auto out_precision = out_type.precision();
573
574 applicator::ScalarUnaryNotNullStateful<O, I, RealToDecimal> kernel(
575 RealToDecimal{out_scale, out_precision, options.allow_decimal_truncate});
576 return kernel.Exec(ctx, batch, out);
577 }
578 };
579
580 // ----------------------------------------------------------------------
581 // Decimal to real
582
583 struct DecimalToReal {
584 template <typename RealType, typename Arg0Value>
585 RealType Call(KernelContext*, const Arg0Value& val, Status*) const {
586 return val.template ToReal<RealType>(in_scale_);
587 }
588
589 int32_t in_scale_;
590 };
591
592 template <typename O, typename I>
593 struct CastFunctor<O, I,
594 enable_if_t<is_floating_type<O>::value && is_decimal_type<I>::value>> {
Execarrow::compute::internal::CastFunctor595 static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
596 const auto& in_type = checked_cast<const I&>(*batch[0].type());
597 const auto in_scale = in_type.scale();
598
599 applicator::ScalarUnaryNotNullStateful<O, I, DecimalToReal> kernel(
600 DecimalToReal{in_scale});
601 return kernel.Exec(ctx, batch, out);
602 }
603 };
604
605 // ----------------------------------------------------------------------
606 // Top-level kernel instantiation
607
608 namespace {
609
610 template <typename OutType>
AddCommonNumberCasts(const std::shared_ptr<DataType> & out_ty,CastFunction * func)611 void AddCommonNumberCasts(const std::shared_ptr<DataType>& out_ty, CastFunction* func) {
612 AddCommonCasts(out_ty->id(), out_ty, func);
613
614 // Cast from boolean to number
615 DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty,
616 CastFunctor<OutType, BooleanType>::Exec));
617
618 // Cast from other strings
619 for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
620 auto exec = GenerateVarBinaryBase<CastFunctor, OutType>(*in_ty);
621 DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, exec));
622 }
623 }
624
625 template <typename OutType>
GetCastToInteger(std::string name)626 std::shared_ptr<CastFunction> GetCastToInteger(std::string name) {
627 auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id);
628 auto out_ty = TypeTraits<OutType>::type_singleton();
629
630 for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
631 DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToInteger));
632 }
633
634 // Cast from floating point
635 for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) {
636 DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger));
637 }
638
639 // From other numbers to integer
640 AddCommonNumberCasts<OutType>(out_ty, func.get());
641
642 // From decimal to integer
643 DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
644 CastFunctor<OutType, Decimal128Type>::Exec));
645 DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
646 CastFunctor<OutType, Decimal256Type>::Exec));
647 return func;
648 }
649
650 template <typename OutType>
GetCastToFloating(std::string name)651 std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
652 auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id);
653 auto out_ty = TypeTraits<OutType>::type_singleton();
654
655 // Casts from integer to floating point
656 for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
657 DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToFloating));
658 }
659
660 // Cast from floating point
661 for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) {
662 DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating));
663 }
664
665 // From other numbers to floating point
666 AddCommonNumberCasts<OutType>(out_ty, func.get());
667
668 // From decimal to floating point
669 DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
670 CastFunctor<OutType, Decimal128Type>::Exec));
671 DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
672 CastFunctor<OutType, Decimal256Type>::Exec));
673 return func;
674 }
675
GetCastToDecimal128()676 std::shared_ptr<CastFunction> GetCastToDecimal128() {
677 OutputType sig_out_ty(ResolveOutputFromOptions);
678
679 auto func = std::make_shared<CastFunction>("cast_decimal", Type::DECIMAL128);
680 AddCommonCasts(Type::DECIMAL128, sig_out_ty, func.get());
681
682 // Cast from floating point
683 DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
684 CastFunctor<Decimal128Type, FloatType>::Exec));
685 DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
686 CastFunctor<Decimal128Type, DoubleType>::Exec));
687
688 // Cast from integer
689 for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
690 auto exec = GenerateInteger<CastFunctor, Decimal128Type>(in_ty->id());
691 DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
692 }
693
694 // Cast from other decimal
695 auto exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec;
696 // We resolve the output type of this kernel from the CastOptions
697 DCHECK_OK(
698 func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
699 exec = CastFunctor<Decimal128Type, Decimal256Type>::Exec;
700 DCHECK_OK(
701 func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
702 return func;
703 }
704
GetCastToDecimal256()705 std::shared_ptr<CastFunction> GetCastToDecimal256() {
706 OutputType sig_out_ty(ResolveOutputFromOptions);
707
708 auto func = std::make_shared<CastFunction>("cast_decimal256", Type::DECIMAL256);
709 AddCommonCasts(Type::DECIMAL256, sig_out_ty, func.get());
710
711 // Cast from floating point
712 DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
713 CastFunctor<Decimal256Type, FloatType>::Exec));
714 DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
715 CastFunctor<Decimal256Type, DoubleType>::Exec));
716
717 // Cast from integer
718 for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
719 auto exec = GenerateInteger<CastFunctor, Decimal256Type>(in_ty->id());
720 DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
721 }
722
723 // Cast from other decimal
724 auto exec = CastFunctor<Decimal256Type, Decimal128Type>::Exec;
725 DCHECK_OK(
726 func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
727 exec = CastFunctor<Decimal256Type, Decimal256Type>::Exec;
728 DCHECK_OK(
729 func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
730 return func;
731 }
732
733 } // namespace
734
GetNumericCasts()735 std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
736 std::vector<std::shared_ptr<CastFunction>> functions;
737
738 // Make a cast to null that does not do much. Not sure why we need to be able
739 // to cast from dict<null> -> null but there are unit tests for it
740 auto cast_null = std::make_shared<CastFunction>("cast_null", Type::NA);
741 DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(),
742 OutputAllNull));
743 functions.push_back(cast_null);
744
745 functions.push_back(GetCastToInteger<Int8Type>("cast_int8"));
746 functions.push_back(GetCastToInteger<Int16Type>("cast_int16"));
747
748 auto cast_int32 = GetCastToInteger<Int32Type>("cast_int32");
749 // Convert DATE32 or TIME32 to INT32 zero copy
750 AddZeroCopyCast(Type::DATE32, date32(), int32(), cast_int32.get());
751 AddZeroCopyCast(Type::TIME32, InputType(Type::TIME32), int32(), cast_int32.get());
752 functions.push_back(cast_int32);
753
754 auto cast_int64 = GetCastToInteger<Int64Type>("cast_int64");
755 // Convert DATE64, DURATION, TIMESTAMP, TIME64 to INT64 zero copy
756 AddZeroCopyCast(Type::DATE64, InputType(Type::DATE64), int64(), cast_int64.get());
757 AddZeroCopyCast(Type::DURATION, InputType(Type::DURATION), int64(), cast_int64.get());
758 AddZeroCopyCast(Type::TIMESTAMP, InputType(Type::TIMESTAMP), int64(), cast_int64.get());
759 AddZeroCopyCast(Type::TIME64, InputType(Type::TIME64), int64(), cast_int64.get());
760 functions.push_back(cast_int64);
761
762 functions.push_back(GetCastToInteger<UInt8Type>("cast_uint8"));
763 functions.push_back(GetCastToInteger<UInt16Type>("cast_uint16"));
764 functions.push_back(GetCastToInteger<UInt32Type>("cast_uint32"));
765 functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64"));
766
767 // HalfFloat is a bit brain-damaged for now
768 auto cast_half_float =
769 std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT);
770 AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get());
771 functions.push_back(cast_half_float);
772
773 functions.push_back(GetCastToFloating<FloatType>("cast_float"));
774 functions.push_back(GetCastToFloating<DoubleType>("cast_double"));
775
776 functions.push_back(GetCastToDecimal128());
777 functions.push_back(GetCastToDecimal256());
778
779 return functions;
780 }
781
782 } // namespace internal
783 } // namespace compute
784 } // namespace arrow
785