1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 // Functions for comparing Arrow data structures
19 
20 #include "arrow/compare.h"
21 
22 #include <climits>
23 #include <cmath>
24 #include <cstdint>
25 #include <cstring>
26 #include <memory>
27 #include <string>
28 #include <type_traits>
29 #include <utility>
30 #include <vector>
31 
32 #include "arrow/array.h"
33 #include "arrow/array/diff.h"
34 #include "arrow/buffer.h"
35 #include "arrow/scalar.h"
36 #include "arrow/sparse_tensor.h"
37 #include "arrow/status.h"
38 #include "arrow/tensor.h"
39 #include "arrow/type.h"
40 #include "arrow/type_traits.h"
41 #include "arrow/util/bit_util.h"
42 #include "arrow/util/checked_cast.h"
43 #include "arrow/util/logging.h"
44 #include "arrow/util/macros.h"
45 #include "arrow/util/memory.h"
46 #include "arrow/visitor_inline.h"
47 
48 namespace arrow {
49 
50 using internal::BitmapEquals;
51 using internal::checked_cast;
52 
53 // ----------------------------------------------------------------------
54 // Public method implementations
55 
56 namespace {
57 
58 // These helper functions assume we already checked the arrays have equal
59 // sizes and null bitmaps.
60 
61 template <typename ArrowType, typename EqualityFunc>
BaseFloatingEquals(const NumericArray<ArrowType> & left,const NumericArray<ArrowType> & right,EqualityFunc && equals)62 inline bool BaseFloatingEquals(const NumericArray<ArrowType>& left,
63                                const NumericArray<ArrowType>& right,
64                                EqualityFunc&& equals) {
65   using T = typename ArrowType::c_type;
66 
67   const T* left_data = left.raw_values();
68   const T* right_data = right.raw_values();
69 
70   if (left.null_count() > 0) {
71     for (int64_t i = 0; i < left.length(); ++i) {
72       if (left.IsNull(i)) continue;
73       if (!equals(left_data[i], right_data[i])) {
74         return false;
75       }
76     }
77   } else {
78     for (int64_t i = 0; i < left.length(); ++i) {
79       if (!equals(left_data[i], right_data[i])) {
80         return false;
81       }
82     }
83   }
84   return true;
85 }
86 
87 template <typename ArrowType>
FloatingEquals(const NumericArray<ArrowType> & left,const NumericArray<ArrowType> & right,const EqualOptions & opts)88 inline bool FloatingEquals(const NumericArray<ArrowType>& left,
89                            const NumericArray<ArrowType>& right,
90                            const EqualOptions& opts) {
91   using T = typename ArrowType::c_type;
92 
93   if (opts.nans_equal()) {
94     return BaseFloatingEquals<ArrowType>(left, right, [](T x, T y) -> bool {
95       return (x == y) || (std::isnan(x) && std::isnan(y));
96     });
97   } else {
98     return BaseFloatingEquals<ArrowType>(left, right,
99                                          [](T x, T y) -> bool { return x == y; });
100   }
101 }
102 
103 template <typename ArrowType>
FloatingApproxEquals(const NumericArray<ArrowType> & left,const NumericArray<ArrowType> & right,const EqualOptions & opts)104 inline bool FloatingApproxEquals(const NumericArray<ArrowType>& left,
105                                  const NumericArray<ArrowType>& right,
106                                  const EqualOptions& opts) {
107   using T = typename ArrowType::c_type;
108   const T epsilon = static_cast<T>(opts.atol());
109 
110   if (opts.nans_equal()) {
111     return BaseFloatingEquals<ArrowType>(left, right, [epsilon](T x, T y) -> bool {
112       return (fabs(x - y) <= epsilon) || (std::isnan(x) && std::isnan(y));
113     });
114   } else {
115     return BaseFloatingEquals<ArrowType>(
116         left, right, [epsilon](T x, T y) -> bool { return fabs(x - y) <= epsilon; });
117   }
118 }
119 
120 // RangeEqualsVisitor assumes the range sizes are equal
121 
122 class RangeEqualsVisitor {
123  public:
RangeEqualsVisitor(const Array & right,int64_t left_start_idx,int64_t left_end_idx,int64_t right_start_idx)124   RangeEqualsVisitor(const Array& right, int64_t left_start_idx, int64_t left_end_idx,
125                      int64_t right_start_idx)
126       : right_(right),
127         left_start_idx_(left_start_idx),
128         left_end_idx_(left_end_idx),
129         right_start_idx_(right_start_idx),
130         result_(false) {}
131 
132   template <typename ArrayType>
CompareValues(const ArrayType & left)133   inline Status CompareValues(const ArrayType& left) {
134     const auto& right = checked_cast<const ArrayType&>(right_);
135 
136     for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
137          ++i, ++o_i) {
138       const bool is_null = left.IsNull(i);
139       if (is_null != right.IsNull(o_i) ||
140           (!is_null && left.Value(i) != right.Value(o_i))) {
141         result_ = false;
142         return Status::OK();
143       }
144     }
145     result_ = true;
146     return Status::OK();
147   }
148 
149   template <typename BinaryArrayType>
CompareBinaryRange(const BinaryArrayType & left) const150   bool CompareBinaryRange(const BinaryArrayType& left) const {
151     const auto& right = checked_cast<const BinaryArrayType&>(right_);
152 
153     for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
154          ++i, ++o_i) {
155       const bool is_null = left.IsNull(i);
156       if (is_null != right.IsNull(o_i)) {
157         return false;
158       }
159       if (is_null) continue;
160       const auto begin_offset = left.value_offset(i);
161       const auto end_offset = left.value_offset(i + 1);
162       const auto right_begin_offset = right.value_offset(o_i);
163       const auto right_end_offset = right.value_offset(o_i + 1);
164       // Underlying can't be equal if the size isn't equal
165       if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
166         return false;
167       }
168 
169       if (end_offset - begin_offset > 0 &&
170           std::memcmp(left.value_data()->data() + begin_offset,
171                       right.value_data()->data() + right_begin_offset,
172                       static_cast<size_t>(end_offset - begin_offset))) {
173         return false;
174       }
175     }
176     return true;
177   }
178 
179   template <typename ListArrayType>
CompareLists(const ListArrayType & left)180   bool CompareLists(const ListArrayType& left) {
181     const auto& right = checked_cast<const ListArrayType&>(right_);
182 
183     const std::shared_ptr<Array>& left_values = left.values();
184     const std::shared_ptr<Array>& right_values = right.values();
185 
186     for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
187          ++i, ++o_i) {
188       const bool is_null = left.IsNull(i);
189       if (is_null != right.IsNull(o_i)) {
190         return false;
191       }
192       if (is_null) continue;
193       const auto begin_offset = left.value_offset(i);
194       const auto end_offset = left.value_offset(i + 1);
195       const auto right_begin_offset = right.value_offset(o_i);
196       const auto right_end_offset = right.value_offset(o_i + 1);
197       // Underlying can't be equal if the size isn't equal
198       if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
199         return false;
200       }
201       if (!left_values->RangeEquals(begin_offset, end_offset, right_begin_offset,
202                                     right_values)) {
203         return false;
204       }
205     }
206     return true;
207   }
208 
CompareStructs(const StructArray & left)209   bool CompareStructs(const StructArray& left) {
210     const auto& right = checked_cast<const StructArray&>(right_);
211     bool equal_fields = true;
212     for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
213          ++i, ++o_i) {
214       if (left.IsNull(i) != right.IsNull(o_i)) {
215         return false;
216       }
217       if (left.IsNull(i)) continue;
218       for (int j = 0; j < left.num_fields(); ++j) {
219         // TODO: really we should be comparing stretches of non-null data rather
220         // than looking at one value at a time.
221         equal_fields = left.field(j)->RangeEquals(i, i + 1, o_i, right.field(j));
222         if (!equal_fields) {
223           return false;
224         }
225       }
226     }
227     return true;
228   }
229 
CompareUnions(const UnionArray & left) const230   bool CompareUnions(const UnionArray& left) const {
231     const auto& right = checked_cast<const UnionArray&>(right_);
232 
233     const UnionMode::type union_mode = left.mode();
234     if (union_mode != right.mode()) {
235       return false;
236     }
237 
238     const auto& left_type = checked_cast<const UnionType&>(*left.type());
239 
240     const std::vector<int>& child_ids = left_type.child_ids();
241 
242     const int8_t* left_codes = left.raw_type_codes();
243     const int8_t* right_codes = right.raw_type_codes();
244 
245     for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
246          ++i, ++o_i) {
247       if (left.IsNull(i) != right.IsNull(o_i)) {
248         return false;
249       }
250       if (left.IsNull(i)) continue;
251       if (left_codes[i] != right_codes[o_i]) {
252         return false;
253       }
254 
255       auto child_num = child_ids[left_codes[i]];
256 
257       // TODO(wesm): really we should be comparing stretches of non-null data
258       // rather than looking at one value at a time.
259       if (union_mode == UnionMode::SPARSE) {
260         if (!left.field(child_num)->RangeEquals(i, i + 1, o_i, right.field(child_num))) {
261           return false;
262         }
263       } else {
264         const int32_t offset = left.raw_value_offsets()[i];
265         const int32_t o_offset = right.raw_value_offsets()[o_i];
266         if (!left.field(child_num)->RangeEquals(offset, offset + 1, o_offset,
267                                                 right.field(child_num))) {
268           return false;
269         }
270       }
271     }
272     return true;
273   }
274 
Visit(const BinaryArray & left)275   Status Visit(const BinaryArray& left) {
276     result_ = CompareBinaryRange(left);
277     return Status::OK();
278   }
279 
Visit(const LargeBinaryArray & left)280   Status Visit(const LargeBinaryArray& left) {
281     result_ = CompareBinaryRange(left);
282     return Status::OK();
283   }
284 
Visit(const FixedSizeBinaryArray & left)285   Status Visit(const FixedSizeBinaryArray& left) {
286     const auto& right = checked_cast<const FixedSizeBinaryArray&>(right_);
287 
288     int32_t width = left.byte_width();
289 
290     const uint8_t* left_data = nullptr;
291     const uint8_t* right_data = nullptr;
292 
293     if (left.values()) {
294       left_data = left.raw_values();
295     }
296 
297     if (right.values()) {
298       right_data = right.raw_values();
299     }
300 
301     for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
302          ++i, ++o_i) {
303       const bool is_null = left.IsNull(i);
304       if (is_null != right.IsNull(o_i)) {
305         result_ = false;
306         return Status::OK();
307       }
308       if (is_null) continue;
309 
310       if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
311         result_ = false;
312         return Status::OK();
313       }
314     }
315     result_ = true;
316     return Status::OK();
317   }
318 
Visit(const Decimal128Array & left)319   Status Visit(const Decimal128Array& left) {
320     return Visit(checked_cast<const FixedSizeBinaryArray&>(left));
321   }
322 
Visit(const NullArray & left)323   Status Visit(const NullArray& left) {
324     ARROW_UNUSED(left);
325     result_ = true;
326     return Status::OK();
327   }
328 
329   template <typename T>
Visit(const T & left)330   typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value, Status>::type Visit(
331       const T& left) {
332     return CompareValues<T>(left);
333   }
334 
Visit(const ListArray & left)335   Status Visit(const ListArray& left) {
336     result_ = CompareLists(left);
337     return Status::OK();
338   }
339 
Visit(const LargeListArray & left)340   Status Visit(const LargeListArray& left) {
341     result_ = CompareLists(left);
342     return Status::OK();
343   }
344 
Visit(const FixedSizeListArray & left)345   Status Visit(const FixedSizeListArray& left) {
346     const auto& right = checked_cast<const FixedSizeListArray&>(right_);
347     result_ = left.values()->RangeEquals(
348         left.value_offset(left_start_idx_), left.value_offset(left_end_idx_),
349         right.value_offset(right_start_idx_), right.values());
350     return Status::OK();
351   }
352 
Visit(const StructArray & left)353   Status Visit(const StructArray& left) {
354     result_ = CompareStructs(left);
355     return Status::OK();
356   }
357 
Visit(const UnionArray & left)358   Status Visit(const UnionArray& left) {
359     result_ = CompareUnions(left);
360     return Status::OK();
361   }
362 
Visit(const DictionaryArray & left)363   Status Visit(const DictionaryArray& left) {
364     const auto& right = checked_cast<const DictionaryArray&>(right_);
365     if (!left.dictionary()->Equals(right.dictionary())) {
366       result_ = false;
367       return Status::OK();
368     }
369     result_ = left.indices()->RangeEquals(left_start_idx_, left_end_idx_,
370                                           right_start_idx_, right.indices());
371     return Status::OK();
372   }
373 
Visit(const ExtensionArray & left)374   Status Visit(const ExtensionArray& left) {
375     result_ = (right_.type()->Equals(*left.type()) &&
376                ArrayRangeEquals(*left.storage(),
377                                 *static_cast<const ExtensionArray&>(right_).storage(),
378                                 left_start_idx_, left_end_idx_, right_start_idx_));
379     return Status::OK();
380   }
381 
result() const382   bool result() const { return result_; }
383 
384  protected:
385   const Array& right_;
386   int64_t left_start_idx_;
387   int64_t left_end_idx_;
388   int64_t right_start_idx_;
389 
390   bool result_;
391 };
392 
IsEqualPrimitive(const PrimitiveArray & left,const PrimitiveArray & right)393 static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& right) {
394   const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
395   const int byte_width = size_meta.bit_width() / CHAR_BIT;
396 
397   const uint8_t* left_data = nullptr;
398   const uint8_t* right_data = nullptr;
399 
400   if (left.values()) {
401     left_data = left.values()->data() + left.offset() * byte_width;
402   }
403 
404   if (right.values()) {
405     right_data = right.values()->data() + right.offset() * byte_width;
406   }
407 
408   if (byte_width == 0) {
409     // Special case 0-width data, as the data pointers may be null
410     for (int64_t i = 0; i < left.length(); ++i) {
411       if (left.IsNull(i) != right.IsNull(i)) {
412         return false;
413       }
414     }
415     return true;
416   } else if (left.null_count() > 0) {
417     for (int64_t i = 0; i < left.length(); ++i) {
418       const bool left_null = left.IsNull(i);
419       const bool right_null = right.IsNull(i);
420       if (left_null != right_null) {
421         return false;
422       }
423       if (!left_null && memcmp(left_data, right_data, byte_width) != 0) {
424         return false;
425       }
426       left_data += byte_width;
427       right_data += byte_width;
428     }
429     return true;
430   } else {
431     auto number_of_bytes_to_compare = static_cast<size_t>(byte_width * left.length());
432     return memcmp(left_data, right_data, number_of_bytes_to_compare) == 0;
433   }
434 }
435 
436 // A bit confusing: ArrayEqualsVisitor inherits from RangeEqualsVisitor but
437 // doesn't share the same preconditions.
438 // When RangeEqualsVisitor is called, we only know the range sizes equal.
439 // When ArrayEqualsVisitor is called, we know the sizes and null bitmaps are equal.
440 
441 class ArrayEqualsVisitor : public RangeEqualsVisitor {
442  public:
ArrayEqualsVisitor(const Array & right,const EqualOptions & opts)443   explicit ArrayEqualsVisitor(const Array& right, const EqualOptions& opts)
444       : RangeEqualsVisitor(right, 0, right.length(), 0), opts_(opts) {}
445 
Visit(const NullArray & left)446   Status Visit(const NullArray& left) {
447     ARROW_UNUSED(left);
448     result_ = true;
449     return Status::OK();
450   }
451 
Visit(const BooleanArray & left)452   Status Visit(const BooleanArray& left) {
453     const auto& right = checked_cast<const BooleanArray&>(right_);
454 
455     if (left.null_count() > 0) {
456       const uint8_t* left_data = left.values()->data();
457       const uint8_t* right_data = right.values()->data();
458 
459       for (int64_t i = 0; i < left.length(); ++i) {
460         if (left.IsValid(i) && BitUtil::GetBit(left_data, i + left.offset()) !=
461                                    BitUtil::GetBit(right_data, i + right.offset())) {
462           result_ = false;
463           return Status::OK();
464         }
465       }
466       result_ = true;
467     } else {
468       result_ = BitmapEquals(left.values()->data(), left.offset(), right.values()->data(),
469                              right.offset(), left.length());
470     }
471     return Status::OK();
472   }
473 
474   template <typename T>
475   typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value &&
476                               !std::is_base_of<FloatArray, T>::value &&
477                               !std::is_base_of<DoubleArray, T>::value &&
478                               !std::is_base_of<BooleanArray, T>::value,
479                           Status>::type
Visit(const T & left)480   Visit(const T& left) {
481     result_ = IsEqualPrimitive(left, checked_cast<const PrimitiveArray&>(right_));
482     return Status::OK();
483   }
484 
485   // TODO nan-aware specialization for half-floats
486 
Visit(const FloatArray & left)487   Status Visit(const FloatArray& left) {
488     result_ =
489         FloatingEquals<FloatType>(left, checked_cast<const FloatArray&>(right_), opts_);
490     return Status::OK();
491   }
492 
Visit(const DoubleArray & left)493   Status Visit(const DoubleArray& left) {
494     result_ =
495         FloatingEquals<DoubleType>(left, checked_cast<const DoubleArray&>(right_), opts_);
496     return Status::OK();
497   }
498 
499   template <typename ArrayType>
ValueOffsetsEqual(const ArrayType & left)500   bool ValueOffsetsEqual(const ArrayType& left) {
501     using offset_type = typename ArrayType::offset_type;
502 
503     const auto& right = checked_cast<const ArrayType&>(right_);
504 
505     if (left.offset() == 0 && right.offset() == 0) {
506       return left.value_offsets()->Equals(*right.value_offsets(),
507                                           (left.length() + 1) * sizeof(offset_type));
508     } else {
509       // One of the arrays is sliced; logic is more complicated because the
510       // value offsets are not both 0-based
511       auto left_offsets =
512           reinterpret_cast<const offset_type*>(left.value_offsets()->data()) +
513           left.offset();
514       auto right_offsets =
515           reinterpret_cast<const offset_type*>(right.value_offsets()->data()) +
516           right.offset();
517 
518       for (int64_t i = 0; i < left.length() + 1; ++i) {
519         if (left_offsets[i] - left_offsets[0] != right_offsets[i] - right_offsets[0]) {
520           return false;
521         }
522       }
523       return true;
524     }
525   }
526 
527   template <typename BinaryArrayType>
CompareBinary(const BinaryArrayType & left)528   bool CompareBinary(const BinaryArrayType& left) {
529     const auto& right = checked_cast<const BinaryArrayType&>(right_);
530 
531     bool equal_offsets = ValueOffsetsEqual<BinaryArrayType>(left);
532     if (!equal_offsets) {
533       return false;
534     }
535 
536     if (!left.value_data() && !(right.value_data())) {
537       return true;
538     }
539     if (left.value_offset(left.length()) == left.value_offset(0)) {
540       return true;
541     }
542 
543     const uint8_t* left_data = left.value_data()->data();
544     const uint8_t* right_data = right.value_data()->data();
545 
546     if (left.null_count() == 0) {
547       // Fast path for null count 0, single memcmp
548       if (left.offset() == 0 && right.offset() == 0) {
549         return std::memcmp(left_data, right_data,
550                            left.raw_value_offsets()[left.length()]) == 0;
551       } else {
552         const int64_t total_bytes =
553             left.value_offset(left.length()) - left.value_offset(0);
554         return std::memcmp(left_data + left.value_offset(0),
555                            right_data + right.value_offset(0),
556                            static_cast<size_t>(total_bytes)) == 0;
557       }
558     } else {
559       // ARROW-537: Only compare data in non-null slots
560       auto left_offsets = left.raw_value_offsets();
561       auto right_offsets = right.raw_value_offsets();
562       for (int64_t i = 0; i < left.length(); ++i) {
563         if (left.IsNull(i)) {
564           continue;
565         }
566         if (std::memcmp(left_data + left_offsets[i], right_data + right_offsets[i],
567                         left.value_length(i))) {
568           return false;
569         }
570       }
571       return true;
572     }
573   }
574 
575   template <typename ListArrayType>
CompareList(const ListArrayType & left)576   bool CompareList(const ListArrayType& left) {
577     const auto& right = checked_cast<const ListArrayType&>(right_);
578 
579     bool equal_offsets = ValueOffsetsEqual<ListArrayType>(left);
580     if (!equal_offsets) {
581       return false;
582     }
583 
584     return left.values()->RangeEquals(left.value_offset(0),
585                                       left.value_offset(left.length()),
586                                       right.value_offset(0), right.values());
587   }
588 
Visit(const BinaryArray & left)589   Status Visit(const BinaryArray& left) {
590     result_ = CompareBinary(left);
591     return Status::OK();
592   }
593 
Visit(const LargeBinaryArray & left)594   Status Visit(const LargeBinaryArray& left) {
595     result_ = CompareBinary(left);
596     return Status::OK();
597   }
598 
Visit(const ListArray & left)599   Status Visit(const ListArray& left) {
600     result_ = CompareList(left);
601     return Status::OK();
602   }
603 
Visit(const LargeListArray & left)604   Status Visit(const LargeListArray& left) {
605     result_ = CompareList(left);
606     return Status::OK();
607   }
608 
Visit(const FixedSizeListArray & left)609   Status Visit(const FixedSizeListArray& left) {
610     const auto& right = checked_cast<const FixedSizeListArray&>(right_);
611     result_ =
612         left.values()->RangeEquals(left.value_offset(0), left.value_offset(left.length()),
613                                    right.value_offset(0), right.values());
614     return Status::OK();
615   }
616 
Visit(const DictionaryArray & left)617   Status Visit(const DictionaryArray& left) {
618     const auto& right = checked_cast<const DictionaryArray&>(right_);
619     if (!left.dictionary()->Equals(right.dictionary())) {
620       result_ = false;
621     } else {
622       result_ = left.indices()->Equals(right.indices());
623     }
624     return Status::OK();
625   }
626 
627   template <typename T>
628   typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value,
629                           Status>::type
Visit(const T & left)630   Visit(const T& left) {
631     return RangeEqualsVisitor::Visit(left);
632   }
633 
Visit(const ExtensionArray & left)634   Status Visit(const ExtensionArray& left) {
635     result_ = (right_.type()->Equals(*left.type()) &&
636                ArrayEquals(*left.storage(),
637                            *static_cast<const ExtensionArray&>(right_).storage()));
638     return Status::OK();
639   }
640 
641  protected:
642   const EqualOptions opts_;
643 };
644 
645 class ApproxEqualsVisitor : public ArrayEqualsVisitor {
646  public:
ApproxEqualsVisitor(const Array & right,const EqualOptions & opts)647   explicit ApproxEqualsVisitor(const Array& right, const EqualOptions& opts)
648       : ArrayEqualsVisitor(right, opts) {}
649 
650   using ArrayEqualsVisitor::Visit;
651 
652   // TODO half-floats
653 
Visit(const FloatArray & left)654   Status Visit(const FloatArray& left) {
655     result_ = FloatingApproxEquals<FloatType>(
656         left, checked_cast<const FloatArray&>(right_), opts_);
657     return Status::OK();
658   }
659 
Visit(const DoubleArray & left)660   Status Visit(const DoubleArray& left) {
661     result_ = FloatingApproxEquals<DoubleType>(
662         left, checked_cast<const DoubleArray&>(right_), opts_);
663     return Status::OK();
664   }
665 };
666 
BaseDataEquals(const Array & left,const Array & right)667 static bool BaseDataEquals(const Array& left, const Array& right) {
668   if (left.length() != right.length() || left.null_count() != right.null_count() ||
669       left.type_id() != right.type_id()) {
670     return false;
671   }
672   // ARROW-2567: Ensure that not only the type id but also the type equality
673   // itself is checked.
674   if (!TypeEquals(*left.type(), *right.type(), false /* check_metadata */)) {
675     return false;
676   }
677   if (left.null_count() > 0 && left.null_count() < left.length()) {
678     return BitmapEquals(left.null_bitmap()->data(), left.offset(),
679                         right.null_bitmap()->data(), right.offset(), left.length());
680   }
681   return true;
682 }
683 
684 template <typename VISITOR, typename... Extra>
ArrayEqualsImpl(const Array & left,const Array & right,Extra &&...extra)685 inline bool ArrayEqualsImpl(const Array& left, const Array& right, Extra&&... extra) {
686   bool are_equal;
687   // The arrays are the same object
688   if (&left == &right) {
689     are_equal = true;
690   } else if (!BaseDataEquals(left, right)) {
691     are_equal = false;
692   } else if (left.length() == 0) {
693     are_equal = true;
694   } else if (left.null_count() == left.length()) {
695     are_equal = true;
696   } else {
697     VISITOR visitor(right, std::forward<Extra>(extra)...);
698     auto error = VisitArrayInline(left, &visitor);
699     if (!error.ok()) {
700       DCHECK(false) << "Arrays are not comparable: " << error.ToString();
701     }
702     are_equal = visitor.result();
703   }
704   return are_equal;
705 }
706 
707 class TypeEqualsVisitor {
708  public:
TypeEqualsVisitor(const DataType & right,bool check_metadata)709   explicit TypeEqualsVisitor(const DataType& right, bool check_metadata)
710       : right_(right), check_metadata_(check_metadata), result_(false) {}
711 
VisitChildren(const DataType & left)712   Status VisitChildren(const DataType& left) {
713     if (left.num_fields() != right_.num_fields()) {
714       result_ = false;
715       return Status::OK();
716     }
717 
718     for (int i = 0; i < left.num_fields(); ++i) {
719       if (!left.field(i)->Equals(right_.field(i), check_metadata_)) {
720         result_ = false;
721         return Status::OK();
722       }
723     }
724     result_ = true;
725     return Status::OK();
726   }
727 
728   template <typename T>
729   enable_if_t<is_null_type<T>::value || is_primitive_ctype<T>::value ||
730                   is_base_binary_type<T>::value,
731               Status>
Visit(const T &)732   Visit(const T&) {
733     result_ = true;
734     return Status::OK();
735   }
736 
737   template <typename T>
Visit(const T & left)738   enable_if_interval<T, Status> Visit(const T& left) {
739     const auto& right = checked_cast<const IntervalType&>(right_);
740     result_ = right.interval_type() == left.interval_type();
741     return Status::OK();
742   }
743 
744   template <typename T>
745   enable_if_t<is_time_type<T>::value || is_date_type<T>::value ||
746                   is_duration_type<T>::value,
747               Status>
Visit(const T & left)748   Visit(const T& left) {
749     const auto& right = checked_cast<const T&>(right_);
750     result_ = left.unit() == right.unit();
751     return Status::OK();
752   }
753 
Visit(const TimestampType & left)754   Status Visit(const TimestampType& left) {
755     const auto& right = checked_cast<const TimestampType&>(right_);
756     result_ = left.unit() == right.unit() && left.timezone() == right.timezone();
757     return Status::OK();
758   }
759 
Visit(const FixedSizeBinaryType & left)760   Status Visit(const FixedSizeBinaryType& left) {
761     const auto& right = checked_cast<const FixedSizeBinaryType&>(right_);
762     result_ = left.byte_width() == right.byte_width();
763     return Status::OK();
764   }
765 
Visit(const Decimal128Type & left)766   Status Visit(const Decimal128Type& left) {
767     const auto& right = checked_cast<const Decimal128Type&>(right_);
768     result_ = left.precision() == right.precision() && left.scale() == right.scale();
769     return Status::OK();
770   }
771 
772   template <typename T>
Visit(const T & left)773   enable_if_t<is_list_like_type<T>::value || is_struct_type<T>::value, Status> Visit(
774       const T& left) {
775     return VisitChildren(left);
776   }
777 
Visit(const MapType & left)778   Status Visit(const MapType& left) {
779     const auto& right = checked_cast<const MapType&>(right_);
780     if (left.keys_sorted() != right.keys_sorted()) {
781       result_ = false;
782       return Status::OK();
783     }
784     return VisitChildren(left);
785   }
786 
Visit(const UnionType & left)787   Status Visit(const UnionType& left) {
788     const auto& right = checked_cast<const UnionType&>(right_);
789 
790     if (left.mode() != right.mode() || left.type_codes() != right.type_codes()) {
791       result_ = false;
792       return Status::OK();
793     }
794 
795     result_ = std::equal(
796         left.fields().begin(), left.fields().end(), right.fields().begin(),
797         [this](const std::shared_ptr<Field>& l, const std::shared_ptr<Field>& r) {
798           return l->Equals(r, check_metadata_);
799         });
800     return Status::OK();
801   }
802 
Visit(const DictionaryType & left)803   Status Visit(const DictionaryType& left) {
804     const auto& right = checked_cast<const DictionaryType&>(right_);
805     result_ = left.index_type()->Equals(right.index_type()) &&
806               left.value_type()->Equals(right.value_type()) &&
807               (left.ordered() == right.ordered());
808     return Status::OK();
809   }
810 
Visit(const ExtensionType & left)811   Status Visit(const ExtensionType& left) {
812     result_ = left.ExtensionEquals(static_cast<const ExtensionType&>(right_));
813     return Status::OK();
814   }
815 
result() const816   bool result() const { return result_; }
817 
818  protected:
819   const DataType& right_;
820   bool check_metadata_;
821   bool result_;
822 };
823 
824 class ScalarEqualsVisitor {
825  public:
ScalarEqualsVisitor(const Scalar & right)826   explicit ScalarEqualsVisitor(const Scalar& right) : right_(right), result_(false) {}
827 
Visit(const NullScalar & left)828   Status Visit(const NullScalar& left) {
829     result_ = true;
830     return Status::OK();
831   }
832 
Visit(const BooleanScalar & left)833   Status Visit(const BooleanScalar& left) {
834     const auto& right = checked_cast<const BooleanScalar&>(right_);
835     result_ = left.value == right.value;
836     return Status::OK();
837   }
838 
839   template <typename T>
840   typename std::enable_if<
841       std::is_base_of<internal::PrimitiveScalar<typename T::TypeClass>, T>::value ||
842           std::is_base_of<TemporalScalar<typename T::TypeClass>, T>::value,
843       Status>::type
Visit(const T & left_)844   Visit(const T& left_) {
845     const auto& right = checked_cast<const T&>(right_);
846     result_ = right.value == left_.value;
847     return Status::OK();
848   }
849 
850   template <typename T>
851   typename std::enable_if<std::is_base_of<BaseBinaryScalar, T>::value, Status>::type
Visit(const T & left)852   Visit(const T& left) {
853     const auto& right = checked_cast<const BaseBinaryScalar&>(right_);
854     result_ = internal::SharedPtrEquals(left.value, right.value);
855     return Status::OK();
856   }
857 
Visit(const Decimal128Scalar & left)858   Status Visit(const Decimal128Scalar& left) {
859     const auto& right = checked_cast<const Decimal128Scalar&>(right_);
860     result_ = left.value == right.value;
861     return Status::OK();
862   }
863 
Visit(const ListScalar & left)864   Status Visit(const ListScalar& left) {
865     const auto& right = checked_cast<const ListScalar&>(right_);
866     result_ = internal::SharedPtrEquals(left.value, right.value);
867     return Status::OK();
868   }
869 
Visit(const LargeListScalar & left)870   Status Visit(const LargeListScalar& left) {
871     const auto& right = checked_cast<const LargeListScalar&>(right_);
872     result_ = internal::SharedPtrEquals(left.value, right.value);
873     return Status::OK();
874   }
875 
Visit(const MapScalar & left)876   Status Visit(const MapScalar& left) {
877     const auto& right = checked_cast<const MapScalar&>(right_);
878     result_ = internal::SharedPtrEquals(left.value, right.value);
879     return Status::OK();
880   }
881 
Visit(const FixedSizeListScalar & left)882   Status Visit(const FixedSizeListScalar& left) {
883     const auto& right = checked_cast<const FixedSizeListScalar&>(right_);
884     result_ = internal::SharedPtrEquals(left.value, right.value);
885     return Status::OK();
886   }
887 
Visit(const StructScalar & left)888   Status Visit(const StructScalar& left) {
889     const auto& right = checked_cast<const StructScalar&>(right_);
890 
891     if (right.value.size() != left.value.size()) {
892       result_ = false;
893     } else {
894       bool all_equals = true;
895       for (size_t i = 0; i < left.value.size() && all_equals; i++) {
896         all_equals &= internal::SharedPtrEquals(left.value[i], right.value[i]);
897       }
898       result_ = all_equals;
899     }
900 
901     return Status::OK();
902   }
903 
Visit(const UnionScalar & left)904   Status Visit(const UnionScalar& left) { return Status::NotImplemented("union"); }
905 
Visit(const DictionaryScalar & left)906   Status Visit(const DictionaryScalar& left) {
907     return Status::NotImplemented("dictionary");
908   }
909 
Visit(const ExtensionScalar & left)910   Status Visit(const ExtensionScalar& left) {
911     return Status::NotImplemented("extension");
912   }
913 
result() const914   bool result() const { return result_; }
915 
916  protected:
917   const Scalar& right_;
918   bool result_;
919 };
920 
PrintDiff(const Array & left,const Array & right,std::ostream * os)921 Status PrintDiff(const Array& left, const Array& right, std::ostream* os) {
922   if (os == nullptr) {
923     return Status::OK();
924   }
925 
926   if (!left.type()->Equals(right.type())) {
927     *os << "# Array types differed: " << *left.type() << " vs " << *right.type()
928         << std::endl;
929     return Status::OK();
930   }
931 
932   if (left.type()->id() == Type::DICTIONARY) {
933     *os << "# Dictionary arrays differed" << std::endl;
934 
935     const auto& left_dict = checked_cast<const DictionaryArray&>(left);
936     const auto& right_dict = checked_cast<const DictionaryArray&>(right);
937 
938     *os << "## dictionary diff";
939     auto pos = os->tellp();
940     RETURN_NOT_OK(PrintDiff(*left_dict.dictionary(), *right_dict.dictionary(), os));
941     if (os->tellp() == pos) {
942       *os << std::endl;
943     }
944 
945     *os << "## indices diff";
946     pos = os->tellp();
947     RETURN_NOT_OK(PrintDiff(*left_dict.indices(), *right_dict.indices(), os));
948     if (os->tellp() == pos) {
949       *os << std::endl;
950     }
951     return Status::OK();
952   }
953 
954   ARROW_ASSIGN_OR_RAISE(auto edits, Diff(left, right, default_memory_pool()));
955   ARROW_ASSIGN_OR_RAISE(auto formatter, MakeUnifiedDiffFormatter(*left.type(), os));
956   return formatter(*edits, left, right);
957 }
958 
959 }  // namespace
960 
ArrayEquals(const Array & left,const Array & right,const EqualOptions & opts)961 bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) {
962   bool are_equal = ArrayEqualsImpl<ArrayEqualsVisitor>(left, right, opts);
963   if (!are_equal) {
964     ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink()));
965   }
966   return are_equal;
967 }
968 
ArrayApproxEquals(const Array & left,const Array & right,const EqualOptions & opts)969 bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) {
970   bool are_equal = ArrayEqualsImpl<ApproxEqualsVisitor>(left, right, opts);
971   if (!are_equal) {
972     DCHECK_OK(PrintDiff(left, right, opts.diff_sink()));
973   }
974   return are_equal;
975 }
976 
ArrayRangeEquals(const Array & left,const Array & right,int64_t left_start_idx,int64_t left_end_idx,int64_t right_start_idx)977 bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
978                       int64_t left_end_idx, int64_t right_start_idx) {
979   bool are_equal;
980   if (&left == &right) {
981     are_equal = true;
982   } else if (left.type_id() != right.type_id()) {
983     are_equal = false;
984   } else if (left.length() == 0) {
985     are_equal = true;
986   } else {
987     RangeEqualsVisitor visitor(right, left_start_idx, left_end_idx, right_start_idx);
988     auto error = VisitArrayInline(left, &visitor);
989     if (!error.ok()) {
990       DCHECK(false) << "Arrays are not comparable: " << error.ToString();
991     }
992     are_equal = visitor.result();
993   }
994   return are_equal;
995 }
996 
997 namespace {
998 
StridedIntegerTensorContentEquals(const int dim_index,int64_t left_offset,int64_t right_offset,int elem_size,const Tensor & left,const Tensor & right)999 bool StridedIntegerTensorContentEquals(const int dim_index, int64_t left_offset,
1000                                        int64_t right_offset, int elem_size,
1001                                        const Tensor& left, const Tensor& right) {
1002   const auto n = left.shape()[dim_index];
1003   const auto left_stride = left.strides()[dim_index];
1004   const auto right_stride = right.strides()[dim_index];
1005   if (dim_index == left.ndim() - 1) {
1006     for (int64_t i = 0; i < n; ++i) {
1007       if (memcmp(left.raw_data() + left_offset + i * left_stride,
1008                  right.raw_data() + right_offset + i * right_stride, elem_size) != 0) {
1009         return false;
1010       }
1011     }
1012     return true;
1013   }
1014   for (int64_t i = 0; i < n; ++i) {
1015     if (!StridedIntegerTensorContentEquals(dim_index + 1, left_offset, right_offset,
1016                                            elem_size, left, right)) {
1017       return false;
1018     }
1019     left_offset += left_stride;
1020     right_offset += right_stride;
1021   }
1022   return true;
1023 }
1024 
IntegerTensorEquals(const Tensor & left,const Tensor & right)1025 bool IntegerTensorEquals(const Tensor& left, const Tensor& right) {
1026   bool are_equal;
1027   // The arrays are the same object
1028   if (&left == &right) {
1029     are_equal = true;
1030   } else {
1031     const bool left_row_major_p = left.is_row_major();
1032     const bool left_column_major_p = left.is_column_major();
1033     const bool right_row_major_p = right.is_row_major();
1034     const bool right_column_major_p = right.is_column_major();
1035 
1036     if (!(left_row_major_p && right_row_major_p) &&
1037         !(left_column_major_p && right_column_major_p)) {
1038       const auto& type = checked_cast<const FixedWidthType&>(*left.type());
1039       are_equal =
1040           StridedIntegerTensorContentEquals(0, 0, 0, type.bit_width() / 8, left, right);
1041     } else {
1042       const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
1043       const int byte_width = size_meta.bit_width() / CHAR_BIT;
1044       DCHECK_GT(byte_width, 0);
1045 
1046       const uint8_t* left_data = left.data()->data();
1047       const uint8_t* right_data = right.data()->data();
1048 
1049       are_equal = memcmp(left_data, right_data,
1050                          static_cast<size_t>(byte_width * left.size())) == 0;
1051     }
1052   }
1053   return are_equal;
1054 }
1055 
1056 template <typename DataType>
StridedFloatTensorContentEquals(const int dim_index,int64_t left_offset,int64_t right_offset,const Tensor & left,const Tensor & right,const EqualOptions & opts)1057 bool StridedFloatTensorContentEquals(const int dim_index, int64_t left_offset,
1058                                      int64_t right_offset, const Tensor& left,
1059                                      const Tensor& right, const EqualOptions& opts) {
1060   using c_type = typename DataType::c_type;
1061   static_assert(std::is_floating_point<c_type>::value,
1062                 "DataType must be a floating point type");
1063 
1064   const auto n = left.shape()[dim_index];
1065   const auto left_stride = left.strides()[dim_index];
1066   const auto right_stride = right.strides()[dim_index];
1067   if (dim_index == left.ndim() - 1) {
1068     auto left_data = left.raw_data();
1069     auto right_data = right.raw_data();
1070     if (opts.nans_equal()) {
1071       for (int64_t i = 0; i < n; ++i) {
1072         c_type left_value =
1073             *reinterpret_cast<const c_type*>(left_data + left_offset + i * left_stride);
1074         c_type right_value = *reinterpret_cast<const c_type*>(right_data + right_offset +
1075                                                               i * right_stride);
1076         if (left_value != right_value &&
1077             !(std::isnan(left_value) && std::isnan(right_value))) {
1078           return false;
1079         }
1080       }
1081     } else {
1082       for (int64_t i = 0; i < n; ++i) {
1083         c_type left_value =
1084             *reinterpret_cast<const c_type*>(left_data + left_offset + i * left_stride);
1085         c_type right_value = *reinterpret_cast<const c_type*>(right_data + right_offset +
1086                                                               i * right_stride);
1087         if (left_value != right_value) {
1088           return false;
1089         }
1090       }
1091     }
1092     return true;
1093   }
1094   for (int64_t i = 0; i < n; ++i) {
1095     if (!StridedFloatTensorContentEquals<DataType>(dim_index + 1, left_offset,
1096                                                    right_offset, left, right, opts)) {
1097       return false;
1098     }
1099     left_offset += left_stride;
1100     right_offset += right_stride;
1101   }
1102   return true;
1103 }
1104 
1105 template <typename DataType>
FloatTensorEquals(const Tensor & left,const Tensor & right,const EqualOptions & opts)1106 bool FloatTensorEquals(const Tensor& left, const Tensor& right,
1107                        const EqualOptions& opts) {
1108   return StridedFloatTensorContentEquals<DataType>(0, 0, 0, left, right, opts);
1109 }
1110 
1111 }  // namespace
1112 
TensorEquals(const Tensor & left,const Tensor & right,const EqualOptions & opts)1113 bool TensorEquals(const Tensor& left, const Tensor& right, const EqualOptions& opts) {
1114   if (left.type_id() != right.type_id()) {
1115     return false;
1116   } else if (left.size() == 0 && right.size() == 0) {
1117     return true;
1118   } else if (left.shape() != right.shape()) {
1119     return false;
1120   }
1121 
1122   switch (left.type_id()) {
1123     // TODO: Support half-float tensors
1124     // case Type::HALF_FLOAT:
1125     case Type::FLOAT:
1126       return FloatTensorEquals<FloatType>(left, right, opts);
1127 
1128     case Type::DOUBLE:
1129       return FloatTensorEquals<DoubleType>(left, right, opts);
1130 
1131     default:
1132       return IntegerTensorEquals(left, right);
1133   }
1134 }
1135 
1136 namespace {
1137 
1138 template <typename LeftSparseIndexType, typename RightSparseIndexType>
1139 struct SparseTensorEqualsImpl {
Comparearrow::__anon6163ced00811::SparseTensorEqualsImpl1140   static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left,
1141                       const SparseTensorImpl<RightSparseIndexType>& right,
1142                       const EqualOptions&) {
1143     // TODO(mrkn): should we support the equality among different formats?
1144     return false;
1145   }
1146 };
1147 
IntegerSparseTensorDataEquals(const uint8_t * left_data,const uint8_t * right_data,const int byte_width,const int64_t length)1148 bool IntegerSparseTensorDataEquals(const uint8_t* left_data, const uint8_t* right_data,
1149                                    const int byte_width, const int64_t length) {
1150   if (left_data == right_data) {
1151     return true;
1152   }
1153   return memcmp(left_data, right_data, static_cast<size_t>(byte_width * length)) == 0;
1154 }
1155 
1156 template <typename DataType>
FloatSparseTensorDataEquals(const typename DataType::c_type * left_data,const typename DataType::c_type * right_data,const int64_t length,const EqualOptions & opts)1157 bool FloatSparseTensorDataEquals(const typename DataType::c_type* left_data,
1158                                  const typename DataType::c_type* right_data,
1159                                  const int64_t length, const EqualOptions& opts) {
1160   using c_type = typename DataType::c_type;
1161   static_assert(std::is_floating_point<c_type>::value,
1162                 "DataType must be a floating point type");
1163   if (opts.nans_equal()) {
1164     if (left_data == right_data) {
1165       return true;
1166     }
1167 
1168     for (int64_t i = 0; i < length; ++i) {
1169       const auto left = left_data[i];
1170       const auto right = right_data[i];
1171       if (left != right && !(std::isnan(left) && std::isnan(right))) {
1172         return false;
1173       }
1174     }
1175   } else {
1176     for (int64_t i = 0; i < length; ++i) {
1177       if (left_data[i] != right_data[i]) {
1178         return false;
1179       }
1180     }
1181   }
1182   return true;
1183 }
1184 
1185 template <typename SparseIndexType>
1186 struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
Comparearrow::__anon6163ced00811::SparseTensorEqualsImpl1187   static bool Compare(const SparseTensorImpl<SparseIndexType>& left,
1188                       const SparseTensorImpl<SparseIndexType>& right,
1189                       const EqualOptions& opts) {
1190     DCHECK(left.type()->id() == right.type()->id());
1191     DCHECK(left.shape() == right.shape());
1192 
1193     const auto length = left.non_zero_length();
1194     DCHECK(length == right.non_zero_length());
1195 
1196     const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index());
1197     const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index());
1198 
1199     if (!left_index.Equals(right_index)) {
1200       return false;
1201     }
1202 
1203     const auto& size_meta = checked_cast<const FixedWidthType&>(*left.type());
1204     const int byte_width = size_meta.bit_width() / CHAR_BIT;
1205     DCHECK_GT(byte_width, 0);
1206 
1207     const uint8_t* left_data = left.data()->data();
1208     const uint8_t* right_data = right.data()->data();
1209     switch (left.type()->id()) {
1210       // TODO: Support half-float tensors
1211       // case Type::HALF_FLOAT:
1212       case Type::FLOAT:
1213         return FloatSparseTensorDataEquals<FloatType>(
1214             reinterpret_cast<const float*>(left_data),
1215             reinterpret_cast<const float*>(right_data), length, opts);
1216 
1217       case Type::DOUBLE:
1218         return FloatSparseTensorDataEquals<DoubleType>(
1219             reinterpret_cast<const double*>(left_data),
1220             reinterpret_cast<const double*>(right_data), length, opts);
1221 
1222       default:  // Integer cases
1223         return IntegerSparseTensorDataEquals(left_data, right_data, byte_width, length);
1224     }
1225   }
1226 };
1227 
1228 template <typename SparseIndexType>
SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType> & left,const SparseTensor & right,const EqualOptions & opts)1229 inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left,
1230                                            const SparseTensor& right,
1231                                            const EqualOptions& opts) {
1232   switch (right.format_id()) {
1233     case SparseTensorFormat::COO: {
1234       const auto& right_coo =
1235           checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right);
1236       return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(
1237           left, right_coo, opts);
1238     }
1239 
1240     case SparseTensorFormat::CSR: {
1241       const auto& right_csr =
1242           checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right);
1243       return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(
1244           left, right_csr, opts);
1245     }
1246 
1247     case SparseTensorFormat::CSC: {
1248       const auto& right_csc =
1249           checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(right);
1250       return SparseTensorEqualsImpl<SparseIndexType, SparseCSCIndex>::Compare(
1251           left, right_csc, opts);
1252     }
1253 
1254     case SparseTensorFormat::CSF: {
1255       const auto& right_csf =
1256           checked_cast<const SparseTensorImpl<SparseCSFIndex>&>(right);
1257       return SparseTensorEqualsImpl<SparseIndexType, SparseCSFIndex>::Compare(
1258           left, right_csf, opts);
1259     }
1260 
1261     default:
1262       return false;
1263   }
1264 }
1265 
1266 }  // namespace
1267 
SparseTensorEquals(const SparseTensor & left,const SparseTensor & right,const EqualOptions & opts)1268 bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right,
1269                         const EqualOptions& opts) {
1270   if (left.type()->id() != right.type()->id()) {
1271     return false;
1272   } else if (left.size() == 0 && right.size() == 0) {
1273     return true;
1274   } else if (left.shape() != right.shape()) {
1275     return false;
1276   } else if (left.non_zero_length() != right.non_zero_length()) {
1277     return false;
1278   }
1279 
1280   switch (left.format_id()) {
1281     case SparseTensorFormat::COO: {
1282       const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left);
1283       return SparseTensorEqualsImplDispatch(left_coo, right, opts);
1284     }
1285 
1286     case SparseTensorFormat::CSR: {
1287       const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left);
1288       return SparseTensorEqualsImplDispatch(left_csr, right, opts);
1289     }
1290 
1291     case SparseTensorFormat::CSC: {
1292       const auto& left_csc = checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(left);
1293       return SparseTensorEqualsImplDispatch(left_csc, right, opts);
1294     }
1295 
1296     case SparseTensorFormat::CSF: {
1297       const auto& left_csf = checked_cast<const SparseTensorImpl<SparseCSFIndex>&>(left);
1298       return SparseTensorEqualsImplDispatch(left_csf, right, opts);
1299     }
1300 
1301     default:
1302       return false;
1303   }
1304 }
1305 
TypeEquals(const DataType & left,const DataType & right,bool check_metadata)1306 bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) {
1307   // The arrays are the same object
1308   if (&left == &right) {
1309     return true;
1310   } else if (left.id() != right.id()) {
1311     return false;
1312   } else {
1313     // First try to compute fingerprints
1314     if (check_metadata) {
1315       const auto& left_metadata_fp = left.metadata_fingerprint();
1316       const auto& right_metadata_fp = right.metadata_fingerprint();
1317       if (left_metadata_fp != right_metadata_fp) {
1318         return false;
1319       }
1320     }
1321 
1322     const auto& left_fp = left.fingerprint();
1323     const auto& right_fp = right.fingerprint();
1324     if (!left_fp.empty() && !right_fp.empty()) {
1325       return left_fp == right_fp;
1326     }
1327 
1328     // TODO remove check_metadata here?
1329     TypeEqualsVisitor visitor(right, check_metadata);
1330     auto error = VisitTypeInline(left, &visitor);
1331     if (!error.ok()) {
1332       DCHECK(false) << "Types are not comparable: " << error.ToString();
1333     }
1334     return visitor.result();
1335   }
1336 }
1337 
ScalarEquals(const Scalar & left,const Scalar & right)1338 bool ScalarEquals(const Scalar& left, const Scalar& right) {
1339   bool are_equal = false;
1340   if (&left == &right) {
1341     are_equal = true;
1342   } else if (!left.type->Equals(right.type)) {
1343     are_equal = false;
1344   } else if (left.is_valid != right.is_valid) {
1345     are_equal = false;
1346   } else {
1347     ScalarEqualsVisitor visitor(right);
1348     auto error = VisitScalarInline(left, &visitor);
1349     DCHECK_OK(error);
1350     are_equal = visitor.result();
1351   }
1352   return are_equal;
1353 }
1354 
1355 }  // namespace arrow
1356