1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 // Functions for comparing Arrow data structures
19
20 #include "arrow/compare.h"
21
22 #include <climits>
23 #include <cmath>
24 #include <cstdint>
25 #include <cstring>
26 #include <memory>
27 #include <string>
28 #include <type_traits>
29 #include <utility>
30 #include <vector>
31
32 #include "arrow/array.h"
33 #include "arrow/array/diff.h"
34 #include "arrow/buffer.h"
35 #include "arrow/scalar.h"
36 #include "arrow/sparse_tensor.h"
37 #include "arrow/status.h"
38 #include "arrow/tensor.h"
39 #include "arrow/type.h"
40 #include "arrow/type_traits.h"
41 #include "arrow/util/bit_util.h"
42 #include "arrow/util/checked_cast.h"
43 #include "arrow/util/logging.h"
44 #include "arrow/util/macros.h"
45 #include "arrow/util/memory.h"
46 #include "arrow/visitor_inline.h"
47
48 namespace arrow {
49
50 using internal::BitmapEquals;
51 using internal::checked_cast;
52
53 // ----------------------------------------------------------------------
54 // Public method implementations
55
56 namespace {
57
58 // These helper functions assume we already checked the arrays have equal
59 // sizes and null bitmaps.
60
61 template <typename ArrowType, typename EqualityFunc>
BaseFloatingEquals(const NumericArray<ArrowType> & left,const NumericArray<ArrowType> & right,EqualityFunc && equals)62 inline bool BaseFloatingEquals(const NumericArray<ArrowType>& left,
63 const NumericArray<ArrowType>& right,
64 EqualityFunc&& equals) {
65 using T = typename ArrowType::c_type;
66
67 const T* left_data = left.raw_values();
68 const T* right_data = right.raw_values();
69
70 if (left.null_count() > 0) {
71 for (int64_t i = 0; i < left.length(); ++i) {
72 if (left.IsNull(i)) continue;
73 if (!equals(left_data[i], right_data[i])) {
74 return false;
75 }
76 }
77 } else {
78 for (int64_t i = 0; i < left.length(); ++i) {
79 if (!equals(left_data[i], right_data[i])) {
80 return false;
81 }
82 }
83 }
84 return true;
85 }
86
87 template <typename ArrowType>
FloatingEquals(const NumericArray<ArrowType> & left,const NumericArray<ArrowType> & right,const EqualOptions & opts)88 inline bool FloatingEquals(const NumericArray<ArrowType>& left,
89 const NumericArray<ArrowType>& right,
90 const EqualOptions& opts) {
91 using T = typename ArrowType::c_type;
92
93 if (opts.nans_equal()) {
94 return BaseFloatingEquals<ArrowType>(left, right, [](T x, T y) -> bool {
95 return (x == y) || (std::isnan(x) && std::isnan(y));
96 });
97 } else {
98 return BaseFloatingEquals<ArrowType>(left, right,
99 [](T x, T y) -> bool { return x == y; });
100 }
101 }
102
103 template <typename ArrowType>
FloatingApproxEquals(const NumericArray<ArrowType> & left,const NumericArray<ArrowType> & right,const EqualOptions & opts)104 inline bool FloatingApproxEquals(const NumericArray<ArrowType>& left,
105 const NumericArray<ArrowType>& right,
106 const EqualOptions& opts) {
107 using T = typename ArrowType::c_type;
108 const T epsilon = static_cast<T>(opts.atol());
109
110 if (opts.nans_equal()) {
111 return BaseFloatingEquals<ArrowType>(left, right, [epsilon](T x, T y) -> bool {
112 return (fabs(x - y) <= epsilon) || (std::isnan(x) && std::isnan(y));
113 });
114 } else {
115 return BaseFloatingEquals<ArrowType>(
116 left, right, [epsilon](T x, T y) -> bool { return fabs(x - y) <= epsilon; });
117 }
118 }
119
120 // RangeEqualsVisitor assumes the range sizes are equal
121
122 class RangeEqualsVisitor {
123 public:
RangeEqualsVisitor(const Array & right,int64_t left_start_idx,int64_t left_end_idx,int64_t right_start_idx)124 RangeEqualsVisitor(const Array& right, int64_t left_start_idx, int64_t left_end_idx,
125 int64_t right_start_idx)
126 : right_(right),
127 left_start_idx_(left_start_idx),
128 left_end_idx_(left_end_idx),
129 right_start_idx_(right_start_idx),
130 result_(false) {}
131
132 template <typename ArrayType>
CompareValues(const ArrayType & left)133 inline Status CompareValues(const ArrayType& left) {
134 const auto& right = checked_cast<const ArrayType&>(right_);
135
136 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
137 ++i, ++o_i) {
138 const bool is_null = left.IsNull(i);
139 if (is_null != right.IsNull(o_i) ||
140 (!is_null && left.Value(i) != right.Value(o_i))) {
141 result_ = false;
142 return Status::OK();
143 }
144 }
145 result_ = true;
146 return Status::OK();
147 }
148
149 template <typename BinaryArrayType>
CompareBinaryRange(const BinaryArrayType & left) const150 bool CompareBinaryRange(const BinaryArrayType& left) const {
151 const auto& right = checked_cast<const BinaryArrayType&>(right_);
152
153 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
154 ++i, ++o_i) {
155 const bool is_null = left.IsNull(i);
156 if (is_null != right.IsNull(o_i)) {
157 return false;
158 }
159 if (is_null) continue;
160 const auto begin_offset = left.value_offset(i);
161 const auto end_offset = left.value_offset(i + 1);
162 const auto right_begin_offset = right.value_offset(o_i);
163 const auto right_end_offset = right.value_offset(o_i + 1);
164 // Underlying can't be equal if the size isn't equal
165 if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
166 return false;
167 }
168
169 if (end_offset - begin_offset > 0 &&
170 std::memcmp(left.value_data()->data() + begin_offset,
171 right.value_data()->data() + right_begin_offset,
172 static_cast<size_t>(end_offset - begin_offset))) {
173 return false;
174 }
175 }
176 return true;
177 }
178
179 template <typename ListArrayType>
CompareLists(const ListArrayType & left)180 bool CompareLists(const ListArrayType& left) {
181 const auto& right = checked_cast<const ListArrayType&>(right_);
182
183 const std::shared_ptr<Array>& left_values = left.values();
184 const std::shared_ptr<Array>& right_values = right.values();
185
186 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
187 ++i, ++o_i) {
188 const bool is_null = left.IsNull(i);
189 if (is_null != right.IsNull(o_i)) {
190 return false;
191 }
192 if (is_null) continue;
193 const auto begin_offset = left.value_offset(i);
194 const auto end_offset = left.value_offset(i + 1);
195 const auto right_begin_offset = right.value_offset(o_i);
196 const auto right_end_offset = right.value_offset(o_i + 1);
197 // Underlying can't be equal if the size isn't equal
198 if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
199 return false;
200 }
201 if (!left_values->RangeEquals(begin_offset, end_offset, right_begin_offset,
202 right_values)) {
203 return false;
204 }
205 }
206 return true;
207 }
208
CompareStructs(const StructArray & left)209 bool CompareStructs(const StructArray& left) {
210 const auto& right = checked_cast<const StructArray&>(right_);
211 bool equal_fields = true;
212 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
213 ++i, ++o_i) {
214 if (left.IsNull(i) != right.IsNull(o_i)) {
215 return false;
216 }
217 if (left.IsNull(i)) continue;
218 for (int j = 0; j < left.num_fields(); ++j) {
219 // TODO: really we should be comparing stretches of non-null data rather
220 // than looking at one value at a time.
221 equal_fields = left.field(j)->RangeEquals(i, i + 1, o_i, right.field(j));
222 if (!equal_fields) {
223 return false;
224 }
225 }
226 }
227 return true;
228 }
229
CompareUnions(const UnionArray & left) const230 bool CompareUnions(const UnionArray& left) const {
231 const auto& right = checked_cast<const UnionArray&>(right_);
232
233 const UnionMode::type union_mode = left.mode();
234 if (union_mode != right.mode()) {
235 return false;
236 }
237
238 const auto& left_type = checked_cast<const UnionType&>(*left.type());
239
240 const std::vector<int>& child_ids = left_type.child_ids();
241
242 const int8_t* left_codes = left.raw_type_codes();
243 const int8_t* right_codes = right.raw_type_codes();
244
245 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
246 ++i, ++o_i) {
247 if (left.IsNull(i) != right.IsNull(o_i)) {
248 return false;
249 }
250 if (left.IsNull(i)) continue;
251 if (left_codes[i] != right_codes[o_i]) {
252 return false;
253 }
254
255 auto child_num = child_ids[left_codes[i]];
256
257 // TODO(wesm): really we should be comparing stretches of non-null data
258 // rather than looking at one value at a time.
259 if (union_mode == UnionMode::SPARSE) {
260 if (!left.field(child_num)->RangeEquals(i, i + 1, o_i, right.field(child_num))) {
261 return false;
262 }
263 } else {
264 const int32_t offset = left.raw_value_offsets()[i];
265 const int32_t o_offset = right.raw_value_offsets()[o_i];
266 if (!left.field(child_num)->RangeEquals(offset, offset + 1, o_offset,
267 right.field(child_num))) {
268 return false;
269 }
270 }
271 }
272 return true;
273 }
274
Visit(const BinaryArray & left)275 Status Visit(const BinaryArray& left) {
276 result_ = CompareBinaryRange(left);
277 return Status::OK();
278 }
279
Visit(const LargeBinaryArray & left)280 Status Visit(const LargeBinaryArray& left) {
281 result_ = CompareBinaryRange(left);
282 return Status::OK();
283 }
284
Visit(const FixedSizeBinaryArray & left)285 Status Visit(const FixedSizeBinaryArray& left) {
286 const auto& right = checked_cast<const FixedSizeBinaryArray&>(right_);
287
288 int32_t width = left.byte_width();
289
290 const uint8_t* left_data = nullptr;
291 const uint8_t* right_data = nullptr;
292
293 if (left.values()) {
294 left_data = left.raw_values();
295 }
296
297 if (right.values()) {
298 right_data = right.raw_values();
299 }
300
301 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
302 ++i, ++o_i) {
303 const bool is_null = left.IsNull(i);
304 if (is_null != right.IsNull(o_i)) {
305 result_ = false;
306 return Status::OK();
307 }
308 if (is_null) continue;
309
310 if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
311 result_ = false;
312 return Status::OK();
313 }
314 }
315 result_ = true;
316 return Status::OK();
317 }
318
Visit(const Decimal128Array & left)319 Status Visit(const Decimal128Array& left) {
320 return Visit(checked_cast<const FixedSizeBinaryArray&>(left));
321 }
322
Visit(const NullArray & left)323 Status Visit(const NullArray& left) {
324 ARROW_UNUSED(left);
325 result_ = true;
326 return Status::OK();
327 }
328
329 template <typename T>
Visit(const T & left)330 typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value, Status>::type Visit(
331 const T& left) {
332 return CompareValues<T>(left);
333 }
334
Visit(const ListArray & left)335 Status Visit(const ListArray& left) {
336 result_ = CompareLists(left);
337 return Status::OK();
338 }
339
Visit(const LargeListArray & left)340 Status Visit(const LargeListArray& left) {
341 result_ = CompareLists(left);
342 return Status::OK();
343 }
344
Visit(const FixedSizeListArray & left)345 Status Visit(const FixedSizeListArray& left) {
346 const auto& right = checked_cast<const FixedSizeListArray&>(right_);
347 result_ = left.values()->RangeEquals(
348 left.value_offset(left_start_idx_), left.value_offset(left_end_idx_),
349 right.value_offset(right_start_idx_), right.values());
350 return Status::OK();
351 }
352
Visit(const StructArray & left)353 Status Visit(const StructArray& left) {
354 result_ = CompareStructs(left);
355 return Status::OK();
356 }
357
Visit(const UnionArray & left)358 Status Visit(const UnionArray& left) {
359 result_ = CompareUnions(left);
360 return Status::OK();
361 }
362
Visit(const DictionaryArray & left)363 Status Visit(const DictionaryArray& left) {
364 const auto& right = checked_cast<const DictionaryArray&>(right_);
365 if (!left.dictionary()->Equals(right.dictionary())) {
366 result_ = false;
367 return Status::OK();
368 }
369 result_ = left.indices()->RangeEquals(left_start_idx_, left_end_idx_,
370 right_start_idx_, right.indices());
371 return Status::OK();
372 }
373
Visit(const ExtensionArray & left)374 Status Visit(const ExtensionArray& left) {
375 result_ = (right_.type()->Equals(*left.type()) &&
376 ArrayRangeEquals(*left.storage(),
377 *static_cast<const ExtensionArray&>(right_).storage(),
378 left_start_idx_, left_end_idx_, right_start_idx_));
379 return Status::OK();
380 }
381
result() const382 bool result() const { return result_; }
383
384 protected:
385 const Array& right_;
386 int64_t left_start_idx_;
387 int64_t left_end_idx_;
388 int64_t right_start_idx_;
389
390 bool result_;
391 };
392
IsEqualPrimitive(const PrimitiveArray & left,const PrimitiveArray & right)393 static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& right) {
394 const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
395 const int byte_width = size_meta.bit_width() / CHAR_BIT;
396
397 const uint8_t* left_data = nullptr;
398 const uint8_t* right_data = nullptr;
399
400 if (left.values()) {
401 left_data = left.values()->data() + left.offset() * byte_width;
402 }
403
404 if (right.values()) {
405 right_data = right.values()->data() + right.offset() * byte_width;
406 }
407
408 if (byte_width == 0) {
409 // Special case 0-width data, as the data pointers may be null
410 for (int64_t i = 0; i < left.length(); ++i) {
411 if (left.IsNull(i) != right.IsNull(i)) {
412 return false;
413 }
414 }
415 return true;
416 } else if (left.null_count() > 0) {
417 for (int64_t i = 0; i < left.length(); ++i) {
418 const bool left_null = left.IsNull(i);
419 const bool right_null = right.IsNull(i);
420 if (left_null != right_null) {
421 return false;
422 }
423 if (!left_null && memcmp(left_data, right_data, byte_width) != 0) {
424 return false;
425 }
426 left_data += byte_width;
427 right_data += byte_width;
428 }
429 return true;
430 } else {
431 auto number_of_bytes_to_compare = static_cast<size_t>(byte_width * left.length());
432 return memcmp(left_data, right_data, number_of_bytes_to_compare) == 0;
433 }
434 }
435
436 // A bit confusing: ArrayEqualsVisitor inherits from RangeEqualsVisitor but
437 // doesn't share the same preconditions.
// When RangeEqualsVisitor is called, we only know that the range sizes are equal.
// When ArrayEqualsVisitor is called, we know that the sizes and null bitmaps are equal.
440
441 class ArrayEqualsVisitor : public RangeEqualsVisitor {
442 public:
ArrayEqualsVisitor(const Array & right,const EqualOptions & opts)443 explicit ArrayEqualsVisitor(const Array& right, const EqualOptions& opts)
444 : RangeEqualsVisitor(right, 0, right.length(), 0), opts_(opts) {}
445
Visit(const NullArray & left)446 Status Visit(const NullArray& left) {
447 ARROW_UNUSED(left);
448 result_ = true;
449 return Status::OK();
450 }
451
Visit(const BooleanArray & left)452 Status Visit(const BooleanArray& left) {
453 const auto& right = checked_cast<const BooleanArray&>(right_);
454
455 if (left.null_count() > 0) {
456 const uint8_t* left_data = left.values()->data();
457 const uint8_t* right_data = right.values()->data();
458
459 for (int64_t i = 0; i < left.length(); ++i) {
460 if (left.IsValid(i) && BitUtil::GetBit(left_data, i + left.offset()) !=
461 BitUtil::GetBit(right_data, i + right.offset())) {
462 result_ = false;
463 return Status::OK();
464 }
465 }
466 result_ = true;
467 } else {
468 result_ = BitmapEquals(left.values()->data(), left.offset(), right.values()->data(),
469 right.offset(), left.length());
470 }
471 return Status::OK();
472 }
473
474 template <typename T>
475 typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value &&
476 !std::is_base_of<FloatArray, T>::value &&
477 !std::is_base_of<DoubleArray, T>::value &&
478 !std::is_base_of<BooleanArray, T>::value,
479 Status>::type
Visit(const T & left)480 Visit(const T& left) {
481 result_ = IsEqualPrimitive(left, checked_cast<const PrimitiveArray&>(right_));
482 return Status::OK();
483 }
484
485 // TODO nan-aware specialization for half-floats
486
Visit(const FloatArray & left)487 Status Visit(const FloatArray& left) {
488 result_ =
489 FloatingEquals<FloatType>(left, checked_cast<const FloatArray&>(right_), opts_);
490 return Status::OK();
491 }
492
Visit(const DoubleArray & left)493 Status Visit(const DoubleArray& left) {
494 result_ =
495 FloatingEquals<DoubleType>(left, checked_cast<const DoubleArray&>(right_), opts_);
496 return Status::OK();
497 }
498
499 template <typename ArrayType>
ValueOffsetsEqual(const ArrayType & left)500 bool ValueOffsetsEqual(const ArrayType& left) {
501 using offset_type = typename ArrayType::offset_type;
502
503 const auto& right = checked_cast<const ArrayType&>(right_);
504
505 if (left.offset() == 0 && right.offset() == 0) {
506 return left.value_offsets()->Equals(*right.value_offsets(),
507 (left.length() + 1) * sizeof(offset_type));
508 } else {
509 // One of the arrays is sliced; logic is more complicated because the
510 // value offsets are not both 0-based
511 auto left_offsets =
512 reinterpret_cast<const offset_type*>(left.value_offsets()->data()) +
513 left.offset();
514 auto right_offsets =
515 reinterpret_cast<const offset_type*>(right.value_offsets()->data()) +
516 right.offset();
517
518 for (int64_t i = 0; i < left.length() + 1; ++i) {
519 if (left_offsets[i] - left_offsets[0] != right_offsets[i] - right_offsets[0]) {
520 return false;
521 }
522 }
523 return true;
524 }
525 }
526
527 template <typename BinaryArrayType>
CompareBinary(const BinaryArrayType & left)528 bool CompareBinary(const BinaryArrayType& left) {
529 const auto& right = checked_cast<const BinaryArrayType&>(right_);
530
531 bool equal_offsets = ValueOffsetsEqual<BinaryArrayType>(left);
532 if (!equal_offsets) {
533 return false;
534 }
535
536 if (!left.value_data() && !(right.value_data())) {
537 return true;
538 }
539 if (left.value_offset(left.length()) == left.value_offset(0)) {
540 return true;
541 }
542
543 const uint8_t* left_data = left.value_data()->data();
544 const uint8_t* right_data = right.value_data()->data();
545
546 if (left.null_count() == 0) {
547 // Fast path for null count 0, single memcmp
548 if (left.offset() == 0 && right.offset() == 0) {
549 return std::memcmp(left_data, right_data,
550 left.raw_value_offsets()[left.length()]) == 0;
551 } else {
552 const int64_t total_bytes =
553 left.value_offset(left.length()) - left.value_offset(0);
554 return std::memcmp(left_data + left.value_offset(0),
555 right_data + right.value_offset(0),
556 static_cast<size_t>(total_bytes)) == 0;
557 }
558 } else {
559 // ARROW-537: Only compare data in non-null slots
560 auto left_offsets = left.raw_value_offsets();
561 auto right_offsets = right.raw_value_offsets();
562 for (int64_t i = 0; i < left.length(); ++i) {
563 if (left.IsNull(i)) {
564 continue;
565 }
566 if (std::memcmp(left_data + left_offsets[i], right_data + right_offsets[i],
567 left.value_length(i))) {
568 return false;
569 }
570 }
571 return true;
572 }
573 }
574
575 template <typename ListArrayType>
CompareList(const ListArrayType & left)576 bool CompareList(const ListArrayType& left) {
577 const auto& right = checked_cast<const ListArrayType&>(right_);
578
579 bool equal_offsets = ValueOffsetsEqual<ListArrayType>(left);
580 if (!equal_offsets) {
581 return false;
582 }
583
584 return left.values()->RangeEquals(left.value_offset(0),
585 left.value_offset(left.length()),
586 right.value_offset(0), right.values());
587 }
588
Visit(const BinaryArray & left)589 Status Visit(const BinaryArray& left) {
590 result_ = CompareBinary(left);
591 return Status::OK();
592 }
593
Visit(const LargeBinaryArray & left)594 Status Visit(const LargeBinaryArray& left) {
595 result_ = CompareBinary(left);
596 return Status::OK();
597 }
598
Visit(const ListArray & left)599 Status Visit(const ListArray& left) {
600 result_ = CompareList(left);
601 return Status::OK();
602 }
603
Visit(const LargeListArray & left)604 Status Visit(const LargeListArray& left) {
605 result_ = CompareList(left);
606 return Status::OK();
607 }
608
Visit(const FixedSizeListArray & left)609 Status Visit(const FixedSizeListArray& left) {
610 const auto& right = checked_cast<const FixedSizeListArray&>(right_);
611 result_ =
612 left.values()->RangeEquals(left.value_offset(0), left.value_offset(left.length()),
613 right.value_offset(0), right.values());
614 return Status::OK();
615 }
616
Visit(const DictionaryArray & left)617 Status Visit(const DictionaryArray& left) {
618 const auto& right = checked_cast<const DictionaryArray&>(right_);
619 if (!left.dictionary()->Equals(right.dictionary())) {
620 result_ = false;
621 } else {
622 result_ = left.indices()->Equals(right.indices());
623 }
624 return Status::OK();
625 }
626
627 template <typename T>
628 typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value,
629 Status>::type
Visit(const T & left)630 Visit(const T& left) {
631 return RangeEqualsVisitor::Visit(left);
632 }
633
Visit(const ExtensionArray & left)634 Status Visit(const ExtensionArray& left) {
635 result_ = (right_.type()->Equals(*left.type()) &&
636 ArrayEquals(*left.storage(),
637 *static_cast<const ExtensionArray&>(right_).storage()));
638 return Status::OK();
639 }
640
641 protected:
642 const EqualOptions opts_;
643 };
644
645 class ApproxEqualsVisitor : public ArrayEqualsVisitor {
646 public:
ApproxEqualsVisitor(const Array & right,const EqualOptions & opts)647 explicit ApproxEqualsVisitor(const Array& right, const EqualOptions& opts)
648 : ArrayEqualsVisitor(right, opts) {}
649
650 using ArrayEqualsVisitor::Visit;
651
652 // TODO half-floats
653
Visit(const FloatArray & left)654 Status Visit(const FloatArray& left) {
655 result_ = FloatingApproxEquals<FloatType>(
656 left, checked_cast<const FloatArray&>(right_), opts_);
657 return Status::OK();
658 }
659
Visit(const DoubleArray & left)660 Status Visit(const DoubleArray& left) {
661 result_ = FloatingApproxEquals<DoubleType>(
662 left, checked_cast<const DoubleArray&>(right_), opts_);
663 return Status::OK();
664 }
665 };
666
BaseDataEquals(const Array & left,const Array & right)667 static bool BaseDataEquals(const Array& left, const Array& right) {
668 if (left.length() != right.length() || left.null_count() != right.null_count() ||
669 left.type_id() != right.type_id()) {
670 return false;
671 }
672 // ARROW-2567: Ensure that not only the type id but also the type equality
673 // itself is checked.
674 if (!TypeEquals(*left.type(), *right.type(), false /* check_metadata */)) {
675 return false;
676 }
677 if (left.null_count() > 0 && left.null_count() < left.length()) {
678 return BitmapEquals(left.null_bitmap()->data(), left.offset(),
679 right.null_bitmap()->data(), right.offset(), left.length());
680 }
681 return true;
682 }
683
684 template <typename VISITOR, typename... Extra>
ArrayEqualsImpl(const Array & left,const Array & right,Extra &&...extra)685 inline bool ArrayEqualsImpl(const Array& left, const Array& right, Extra&&... extra) {
686 bool are_equal;
687 // The arrays are the same object
688 if (&left == &right) {
689 are_equal = true;
690 } else if (!BaseDataEquals(left, right)) {
691 are_equal = false;
692 } else if (left.length() == 0) {
693 are_equal = true;
694 } else if (left.null_count() == left.length()) {
695 are_equal = true;
696 } else {
697 VISITOR visitor(right, std::forward<Extra>(extra)...);
698 auto error = VisitArrayInline(left, &visitor);
699 if (!error.ok()) {
700 DCHECK(false) << "Arrays are not comparable: " << error.ToString();
701 }
702 are_equal = visitor.result();
703 }
704 return are_equal;
705 }
706
707 class TypeEqualsVisitor {
708 public:
TypeEqualsVisitor(const DataType & right,bool check_metadata)709 explicit TypeEqualsVisitor(const DataType& right, bool check_metadata)
710 : right_(right), check_metadata_(check_metadata), result_(false) {}
711
VisitChildren(const DataType & left)712 Status VisitChildren(const DataType& left) {
713 if (left.num_fields() != right_.num_fields()) {
714 result_ = false;
715 return Status::OK();
716 }
717
718 for (int i = 0; i < left.num_fields(); ++i) {
719 if (!left.field(i)->Equals(right_.field(i), check_metadata_)) {
720 result_ = false;
721 return Status::OK();
722 }
723 }
724 result_ = true;
725 return Status::OK();
726 }
727
728 template <typename T>
729 enable_if_t<is_null_type<T>::value || is_primitive_ctype<T>::value ||
730 is_base_binary_type<T>::value,
731 Status>
Visit(const T &)732 Visit(const T&) {
733 result_ = true;
734 return Status::OK();
735 }
736
737 template <typename T>
Visit(const T & left)738 enable_if_interval<T, Status> Visit(const T& left) {
739 const auto& right = checked_cast<const IntervalType&>(right_);
740 result_ = right.interval_type() == left.interval_type();
741 return Status::OK();
742 }
743
744 template <typename T>
745 enable_if_t<is_time_type<T>::value || is_date_type<T>::value ||
746 is_duration_type<T>::value,
747 Status>
Visit(const T & left)748 Visit(const T& left) {
749 const auto& right = checked_cast<const T&>(right_);
750 result_ = left.unit() == right.unit();
751 return Status::OK();
752 }
753
Visit(const TimestampType & left)754 Status Visit(const TimestampType& left) {
755 const auto& right = checked_cast<const TimestampType&>(right_);
756 result_ = left.unit() == right.unit() && left.timezone() == right.timezone();
757 return Status::OK();
758 }
759
Visit(const FixedSizeBinaryType & left)760 Status Visit(const FixedSizeBinaryType& left) {
761 const auto& right = checked_cast<const FixedSizeBinaryType&>(right_);
762 result_ = left.byte_width() == right.byte_width();
763 return Status::OK();
764 }
765
Visit(const Decimal128Type & left)766 Status Visit(const Decimal128Type& left) {
767 const auto& right = checked_cast<const Decimal128Type&>(right_);
768 result_ = left.precision() == right.precision() && left.scale() == right.scale();
769 return Status::OK();
770 }
771
772 template <typename T>
Visit(const T & left)773 enable_if_t<is_list_like_type<T>::value || is_struct_type<T>::value, Status> Visit(
774 const T& left) {
775 return VisitChildren(left);
776 }
777
Visit(const MapType & left)778 Status Visit(const MapType& left) {
779 const auto& right = checked_cast<const MapType&>(right_);
780 if (left.keys_sorted() != right.keys_sorted()) {
781 result_ = false;
782 return Status::OK();
783 }
784 return VisitChildren(left);
785 }
786
Visit(const UnionType & left)787 Status Visit(const UnionType& left) {
788 const auto& right = checked_cast<const UnionType&>(right_);
789
790 if (left.mode() != right.mode() || left.type_codes() != right.type_codes()) {
791 result_ = false;
792 return Status::OK();
793 }
794
795 result_ = std::equal(
796 left.fields().begin(), left.fields().end(), right.fields().begin(),
797 [this](const std::shared_ptr<Field>& l, const std::shared_ptr<Field>& r) {
798 return l->Equals(r, check_metadata_);
799 });
800 return Status::OK();
801 }
802
Visit(const DictionaryType & left)803 Status Visit(const DictionaryType& left) {
804 const auto& right = checked_cast<const DictionaryType&>(right_);
805 result_ = left.index_type()->Equals(right.index_type()) &&
806 left.value_type()->Equals(right.value_type()) &&
807 (left.ordered() == right.ordered());
808 return Status::OK();
809 }
810
Visit(const ExtensionType & left)811 Status Visit(const ExtensionType& left) {
812 result_ = left.ExtensionEquals(static_cast<const ExtensionType&>(right_));
813 return Status::OK();
814 }
815
result() const816 bool result() const { return result_; }
817
818 protected:
819 const DataType& right_;
820 bool check_metadata_;
821 bool result_;
822 };
823
824 class ScalarEqualsVisitor {
825 public:
ScalarEqualsVisitor(const Scalar & right)826 explicit ScalarEqualsVisitor(const Scalar& right) : right_(right), result_(false) {}
827
Visit(const NullScalar & left)828 Status Visit(const NullScalar& left) {
829 result_ = true;
830 return Status::OK();
831 }
832
Visit(const BooleanScalar & left)833 Status Visit(const BooleanScalar& left) {
834 const auto& right = checked_cast<const BooleanScalar&>(right_);
835 result_ = left.value == right.value;
836 return Status::OK();
837 }
838
839 template <typename T>
840 typename std::enable_if<
841 std::is_base_of<internal::PrimitiveScalar<typename T::TypeClass>, T>::value ||
842 std::is_base_of<TemporalScalar<typename T::TypeClass>, T>::value,
843 Status>::type
Visit(const T & left_)844 Visit(const T& left_) {
845 const auto& right = checked_cast<const T&>(right_);
846 result_ = right.value == left_.value;
847 return Status::OK();
848 }
849
850 template <typename T>
851 typename std::enable_if<std::is_base_of<BaseBinaryScalar, T>::value, Status>::type
Visit(const T & left)852 Visit(const T& left) {
853 const auto& right = checked_cast<const BaseBinaryScalar&>(right_);
854 result_ = internal::SharedPtrEquals(left.value, right.value);
855 return Status::OK();
856 }
857
Visit(const Decimal128Scalar & left)858 Status Visit(const Decimal128Scalar& left) {
859 const auto& right = checked_cast<const Decimal128Scalar&>(right_);
860 result_ = left.value == right.value;
861 return Status::OK();
862 }
863
Visit(const ListScalar & left)864 Status Visit(const ListScalar& left) {
865 const auto& right = checked_cast<const ListScalar&>(right_);
866 result_ = internal::SharedPtrEquals(left.value, right.value);
867 return Status::OK();
868 }
869
Visit(const LargeListScalar & left)870 Status Visit(const LargeListScalar& left) {
871 const auto& right = checked_cast<const LargeListScalar&>(right_);
872 result_ = internal::SharedPtrEquals(left.value, right.value);
873 return Status::OK();
874 }
875
Visit(const MapScalar & left)876 Status Visit(const MapScalar& left) {
877 const auto& right = checked_cast<const MapScalar&>(right_);
878 result_ = internal::SharedPtrEquals(left.value, right.value);
879 return Status::OK();
880 }
881
Visit(const FixedSizeListScalar & left)882 Status Visit(const FixedSizeListScalar& left) {
883 const auto& right = checked_cast<const FixedSizeListScalar&>(right_);
884 result_ = internal::SharedPtrEquals(left.value, right.value);
885 return Status::OK();
886 }
887
Visit(const StructScalar & left)888 Status Visit(const StructScalar& left) {
889 const auto& right = checked_cast<const StructScalar&>(right_);
890
891 if (right.value.size() != left.value.size()) {
892 result_ = false;
893 } else {
894 bool all_equals = true;
895 for (size_t i = 0; i < left.value.size() && all_equals; i++) {
896 all_equals &= internal::SharedPtrEquals(left.value[i], right.value[i]);
897 }
898 result_ = all_equals;
899 }
900
901 return Status::OK();
902 }
903
Visit(const UnionScalar & left)904 Status Visit(const UnionScalar& left) { return Status::NotImplemented("union"); }
905
Visit(const DictionaryScalar & left)906 Status Visit(const DictionaryScalar& left) {
907 return Status::NotImplemented("dictionary");
908 }
909
Visit(const ExtensionScalar & left)910 Status Visit(const ExtensionScalar& left) {
911 return Status::NotImplemented("extension");
912 }
913
result() const914 bool result() const { return result_; }
915
916 protected:
917 const Scalar& right_;
918 bool result_;
919 };
920
PrintDiff(const Array & left,const Array & right,std::ostream * os)921 Status PrintDiff(const Array& left, const Array& right, std::ostream* os) {
922 if (os == nullptr) {
923 return Status::OK();
924 }
925
926 if (!left.type()->Equals(right.type())) {
927 *os << "# Array types differed: " << *left.type() << " vs " << *right.type()
928 << std::endl;
929 return Status::OK();
930 }
931
932 if (left.type()->id() == Type::DICTIONARY) {
933 *os << "# Dictionary arrays differed" << std::endl;
934
935 const auto& left_dict = checked_cast<const DictionaryArray&>(left);
936 const auto& right_dict = checked_cast<const DictionaryArray&>(right);
937
938 *os << "## dictionary diff";
939 auto pos = os->tellp();
940 RETURN_NOT_OK(PrintDiff(*left_dict.dictionary(), *right_dict.dictionary(), os));
941 if (os->tellp() == pos) {
942 *os << std::endl;
943 }
944
945 *os << "## indices diff";
946 pos = os->tellp();
947 RETURN_NOT_OK(PrintDiff(*left_dict.indices(), *right_dict.indices(), os));
948 if (os->tellp() == pos) {
949 *os << std::endl;
950 }
951 return Status::OK();
952 }
953
954 ARROW_ASSIGN_OR_RAISE(auto edits, Diff(left, right, default_memory_pool()));
955 ARROW_ASSIGN_OR_RAISE(auto formatter, MakeUnifiedDiffFormatter(*left.type(), os));
956 return formatter(*edits, left, right);
957 }
958
959 } // namespace
960
ArrayEquals(const Array & left,const Array & right,const EqualOptions & opts)961 bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) {
962 bool are_equal = ArrayEqualsImpl<ArrayEqualsVisitor>(left, right, opts);
963 if (!are_equal) {
964 ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink()));
965 }
966 return are_equal;
967 }
968
ArrayApproxEquals(const Array & left,const Array & right,const EqualOptions & opts)969 bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) {
970 bool are_equal = ArrayEqualsImpl<ApproxEqualsVisitor>(left, right, opts);
971 if (!are_equal) {
972 DCHECK_OK(PrintDiff(left, right, opts.diff_sink()));
973 }
974 return are_equal;
975 }
976
ArrayRangeEquals(const Array & left,const Array & right,int64_t left_start_idx,int64_t left_end_idx,int64_t right_start_idx)977 bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
978 int64_t left_end_idx, int64_t right_start_idx) {
979 bool are_equal;
980 if (&left == &right) {
981 are_equal = true;
982 } else if (left.type_id() != right.type_id()) {
983 are_equal = false;
984 } else if (left.length() == 0) {
985 are_equal = true;
986 } else {
987 RangeEqualsVisitor visitor(right, left_start_idx, left_end_idx, right_start_idx);
988 auto error = VisitArrayInline(left, &visitor);
989 if (!error.ok()) {
990 DCHECK(false) << "Arrays are not comparable: " << error.ToString();
991 }
992 are_equal = visitor.result();
993 }
994 return are_equal;
995 }
996
997 namespace {
998
StridedIntegerTensorContentEquals(const int dim_index,int64_t left_offset,int64_t right_offset,int elem_size,const Tensor & left,const Tensor & right)999 bool StridedIntegerTensorContentEquals(const int dim_index, int64_t left_offset,
1000 int64_t right_offset, int elem_size,
1001 const Tensor& left, const Tensor& right) {
1002 const auto n = left.shape()[dim_index];
1003 const auto left_stride = left.strides()[dim_index];
1004 const auto right_stride = right.strides()[dim_index];
1005 if (dim_index == left.ndim() - 1) {
1006 for (int64_t i = 0; i < n; ++i) {
1007 if (memcmp(left.raw_data() + left_offset + i * left_stride,
1008 right.raw_data() + right_offset + i * right_stride, elem_size) != 0) {
1009 return false;
1010 }
1011 }
1012 return true;
1013 }
1014 for (int64_t i = 0; i < n; ++i) {
1015 if (!StridedIntegerTensorContentEquals(dim_index + 1, left_offset, right_offset,
1016 elem_size, left, right)) {
1017 return false;
1018 }
1019 left_offset += left_stride;
1020 right_offset += right_stride;
1021 }
1022 return true;
1023 }
1024
IntegerTensorEquals(const Tensor & left,const Tensor & right)1025 bool IntegerTensorEquals(const Tensor& left, const Tensor& right) {
1026 bool are_equal;
1027 // The arrays are the same object
1028 if (&left == &right) {
1029 are_equal = true;
1030 } else {
1031 const bool left_row_major_p = left.is_row_major();
1032 const bool left_column_major_p = left.is_column_major();
1033 const bool right_row_major_p = right.is_row_major();
1034 const bool right_column_major_p = right.is_column_major();
1035
1036 if (!(left_row_major_p && right_row_major_p) &&
1037 !(left_column_major_p && right_column_major_p)) {
1038 const auto& type = checked_cast<const FixedWidthType&>(*left.type());
1039 are_equal =
1040 StridedIntegerTensorContentEquals(0, 0, 0, type.bit_width() / 8, left, right);
1041 } else {
1042 const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
1043 const int byte_width = size_meta.bit_width() / CHAR_BIT;
1044 DCHECK_GT(byte_width, 0);
1045
1046 const uint8_t* left_data = left.data()->data();
1047 const uint8_t* right_data = right.data()->data();
1048
1049 are_equal = memcmp(left_data, right_data,
1050 static_cast<size_t>(byte_width * left.size())) == 0;
1051 }
1052 }
1053 return are_equal;
1054 }
1055
1056 template <typename DataType>
StridedFloatTensorContentEquals(const int dim_index,int64_t left_offset,int64_t right_offset,const Tensor & left,const Tensor & right,const EqualOptions & opts)1057 bool StridedFloatTensorContentEquals(const int dim_index, int64_t left_offset,
1058 int64_t right_offset, const Tensor& left,
1059 const Tensor& right, const EqualOptions& opts) {
1060 using c_type = typename DataType::c_type;
1061 static_assert(std::is_floating_point<c_type>::value,
1062 "DataType must be a floating point type");
1063
1064 const auto n = left.shape()[dim_index];
1065 const auto left_stride = left.strides()[dim_index];
1066 const auto right_stride = right.strides()[dim_index];
1067 if (dim_index == left.ndim() - 1) {
1068 auto left_data = left.raw_data();
1069 auto right_data = right.raw_data();
1070 if (opts.nans_equal()) {
1071 for (int64_t i = 0; i < n; ++i) {
1072 c_type left_value =
1073 *reinterpret_cast<const c_type*>(left_data + left_offset + i * left_stride);
1074 c_type right_value = *reinterpret_cast<const c_type*>(right_data + right_offset +
1075 i * right_stride);
1076 if (left_value != right_value &&
1077 !(std::isnan(left_value) && std::isnan(right_value))) {
1078 return false;
1079 }
1080 }
1081 } else {
1082 for (int64_t i = 0; i < n; ++i) {
1083 c_type left_value =
1084 *reinterpret_cast<const c_type*>(left_data + left_offset + i * left_stride);
1085 c_type right_value = *reinterpret_cast<const c_type*>(right_data + right_offset +
1086 i * right_stride);
1087 if (left_value != right_value) {
1088 return false;
1089 }
1090 }
1091 }
1092 return true;
1093 }
1094 for (int64_t i = 0; i < n; ++i) {
1095 if (!StridedFloatTensorContentEquals<DataType>(dim_index + 1, left_offset,
1096 right_offset, left, right, opts)) {
1097 return false;
1098 }
1099 left_offset += left_stride;
1100 right_offset += right_stride;
1101 }
1102 return true;
1103 }
1104
1105 template <typename DataType>
FloatTensorEquals(const Tensor & left,const Tensor & right,const EqualOptions & opts)1106 bool FloatTensorEquals(const Tensor& left, const Tensor& right,
1107 const EqualOptions& opts) {
1108 return StridedFloatTensorContentEquals<DataType>(0, 0, 0, left, right, opts);
1109 }
1110
1111 } // namespace
1112
TensorEquals(const Tensor & left,const Tensor & right,const EqualOptions & opts)1113 bool TensorEquals(const Tensor& left, const Tensor& right, const EqualOptions& opts) {
1114 if (left.type_id() != right.type_id()) {
1115 return false;
1116 } else if (left.size() == 0 && right.size() == 0) {
1117 return true;
1118 } else if (left.shape() != right.shape()) {
1119 return false;
1120 }
1121
1122 switch (left.type_id()) {
1123 // TODO: Support half-float tensors
1124 // case Type::HALF_FLOAT:
1125 case Type::FLOAT:
1126 return FloatTensorEquals<FloatType>(left, right, opts);
1127
1128 case Type::DOUBLE:
1129 return FloatTensorEquals<DoubleType>(left, right, opts);
1130
1131 default:
1132 return IntegerTensorEquals(left, right);
1133 }
1134 }
1135
1136 namespace {
1137
1138 template <typename LeftSparseIndexType, typename RightSparseIndexType>
1139 struct SparseTensorEqualsImpl {
Comparearrow::__anon6163ced00811::SparseTensorEqualsImpl1140 static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left,
1141 const SparseTensorImpl<RightSparseIndexType>& right,
1142 const EqualOptions&) {
1143 // TODO(mrkn): should we support the equality among different formats?
1144 return false;
1145 }
1146 };
1147
// Byte-wise comparison of two value buffers.
//
// \param left_data, right_data start of each buffer
// \param byte_width size in bytes of one value
// \param length number of values to compare
// \return true if both pointers are identical or the first
//         byte_width * length bytes of the buffers are the same
bool IntegerSparseTensorDataEquals(const uint8_t* left_data, const uint8_t* right_data,
                                   const int byte_width, const int64_t length) {
  const bool same_storage = (left_data == right_data);
  // Short-circuit: memcmp is only reached when the pointers differ.
  return same_storage ||
         memcmp(left_data, right_data, static_cast<size_t>(byte_width * length)) == 0;
}
1155
1156 template <typename DataType>
FloatSparseTensorDataEquals(const typename DataType::c_type * left_data,const typename DataType::c_type * right_data,const int64_t length,const EqualOptions & opts)1157 bool FloatSparseTensorDataEquals(const typename DataType::c_type* left_data,
1158 const typename DataType::c_type* right_data,
1159 const int64_t length, const EqualOptions& opts) {
1160 using c_type = typename DataType::c_type;
1161 static_assert(std::is_floating_point<c_type>::value,
1162 "DataType must be a floating point type");
1163 if (opts.nans_equal()) {
1164 if (left_data == right_data) {
1165 return true;
1166 }
1167
1168 for (int64_t i = 0; i < length; ++i) {
1169 const auto left = left_data[i];
1170 const auto right = right_data[i];
1171 if (left != right && !(std::isnan(left) && std::isnan(right))) {
1172 return false;
1173 }
1174 }
1175 } else {
1176 for (int64_t i = 0; i < length; ++i) {
1177 if (left_data[i] != right_data[i]) {
1178 return false;
1179 }
1180 }
1181 }
1182 return true;
1183 }
1184
1185 template <typename SparseIndexType>
1186 struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
Comparearrow::__anon6163ced00811::SparseTensorEqualsImpl1187 static bool Compare(const SparseTensorImpl<SparseIndexType>& left,
1188 const SparseTensorImpl<SparseIndexType>& right,
1189 const EqualOptions& opts) {
1190 DCHECK(left.type()->id() == right.type()->id());
1191 DCHECK(left.shape() == right.shape());
1192
1193 const auto length = left.non_zero_length();
1194 DCHECK(length == right.non_zero_length());
1195
1196 const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index());
1197 const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index());
1198
1199 if (!left_index.Equals(right_index)) {
1200 return false;
1201 }
1202
1203 const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
1204 const int byte_width = size_meta.bit_width() / CHAR_BIT;
1205 DCHECK_GT(byte_width, 0);
1206
1207 const uint8_t* left_data = left.data()->data();
1208 const uint8_t* right_data = right.data()->data();
1209 switch (left.type()->id()) {
1210 // TODO: Support half-float tensors
1211 // case Type::HALF_FLOAT:
1212 case Type::FLOAT:
1213 return FloatSparseTensorDataEquals<FloatType>(
1214 reinterpret_cast<const float*>(left_data),
1215 reinterpret_cast<const float*>(right_data), length, opts);
1216
1217 case Type::DOUBLE:
1218 return FloatSparseTensorDataEquals<DoubleType>(
1219 reinterpret_cast<const double*>(left_data),
1220 reinterpret_cast<const double*>(right_data), length, opts);
1221
1222 default: // Integer cases
1223 return IntegerSparseTensorDataEquals(left_data, right_data, byte_width, length);
1224 }
1225 }
1226 };
1227
1228 template <typename SparseIndexType>
SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType> & left,const SparseTensor & right,const EqualOptions & opts)1229 inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left,
1230 const SparseTensor& right,
1231 const EqualOptions& opts) {
1232 switch (right.format_id()) {
1233 case SparseTensorFormat::COO: {
1234 const auto& right_coo =
1235 checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right);
1236 return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(
1237 left, right_coo, opts);
1238 }
1239
1240 case SparseTensorFormat::CSR: {
1241 const auto& right_csr =
1242 checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right);
1243 return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(
1244 left, right_csr, opts);
1245 }
1246
1247 case SparseTensorFormat::CSC: {
1248 const auto& right_csc =
1249 checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(right);
1250 return SparseTensorEqualsImpl<SparseIndexType, SparseCSCIndex>::Compare(
1251 left, right_csc, opts);
1252 }
1253
1254 case SparseTensorFormat::CSF: {
1255 const auto& right_csf =
1256 checked_cast<const SparseTensorImpl<SparseCSFIndex>&>(right);
1257 return SparseTensorEqualsImpl<SparseIndexType, SparseCSFIndex>::Compare(
1258 left, right_csf, opts);
1259 }
1260
1261 default:
1262 return false;
1263 }
1264 }
1265
1266 } // namespace
1267
SparseTensorEquals(const SparseTensor & left,const SparseTensor & right,const EqualOptions & opts)1268 bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right,
1269 const EqualOptions& opts) {
1270 if (left.type()->id() != right.type()->id()) {
1271 return false;
1272 } else if (left.size() == 0 && right.size() == 0) {
1273 return true;
1274 } else if (left.shape() != right.shape()) {
1275 return false;
1276 } else if (left.non_zero_length() != right.non_zero_length()) {
1277 return false;
1278 }
1279
1280 switch (left.format_id()) {
1281 case SparseTensorFormat::COO: {
1282 const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left);
1283 return SparseTensorEqualsImplDispatch(left_coo, right, opts);
1284 }
1285
1286 case SparseTensorFormat::CSR: {
1287 const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left);
1288 return SparseTensorEqualsImplDispatch(left_csr, right, opts);
1289 }
1290
1291 case SparseTensorFormat::CSC: {
1292 const auto& left_csc = checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(left);
1293 return SparseTensorEqualsImplDispatch(left_csc, right, opts);
1294 }
1295
1296 case SparseTensorFormat::CSF: {
1297 const auto& left_csf = checked_cast<const SparseTensorImpl<SparseCSFIndex>&>(left);
1298 return SparseTensorEqualsImplDispatch(left_csf, right, opts);
1299 }
1300
1301 default:
1302 return false;
1303 }
1304 }
1305
TypeEquals(const DataType & left,const DataType & right,bool check_metadata)1306 bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) {
1307 // The arrays are the same object
1308 if (&left == &right) {
1309 return true;
1310 } else if (left.id() != right.id()) {
1311 return false;
1312 } else {
1313 // First try to compute fingerprints
1314 if (check_metadata) {
1315 const auto& left_metadata_fp = left.metadata_fingerprint();
1316 const auto& right_metadata_fp = right.metadata_fingerprint();
1317 if (left_metadata_fp != right_metadata_fp) {
1318 return false;
1319 }
1320 }
1321
1322 const auto& left_fp = left.fingerprint();
1323 const auto& right_fp = right.fingerprint();
1324 if (!left_fp.empty() && !right_fp.empty()) {
1325 return left_fp == right_fp;
1326 }
1327
1328 // TODO remove check_metadata here?
1329 TypeEqualsVisitor visitor(right, check_metadata);
1330 auto error = VisitTypeInline(left, &visitor);
1331 if (!error.ok()) {
1332 DCHECK(false) << "Types are not comparable: " << error.ToString();
1333 }
1334 return visitor.result();
1335 }
1336 }
1337
ScalarEquals(const Scalar & left,const Scalar & right)1338 bool ScalarEquals(const Scalar& left, const Scalar& right) {
1339 bool are_equal = false;
1340 if (&left == &right) {
1341 are_equal = true;
1342 } else if (!left.type->Equals(right.type)) {
1343 are_equal = false;
1344 } else if (left.is_valid != right.is_valid) {
1345 are_equal = false;
1346 } else {
1347 ScalarEqualsVisitor visitor(right);
1348 auto error = VisitScalarInline(left, &visitor);
1349 DCHECK_OK(error);
1350 are_equal = visitor.result();
1351 }
1352 return are_equal;
1353 }
1354
1355 } // namespace arrow
1356