1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "arrow/testing/gtest_util.h"
19 
20 #include "arrow/testing/extension_type.h"
21 
22 #ifndef _WIN32
23 #include <sys/stat.h>  // IWYU pragma: keep
24 #include <sys/wait.h>  // IWYU pragma: keep
25 #include <unistd.h>    // IWYU pragma: keep
26 #endif
27 
28 #include <algorithm>
29 #include <chrono>
30 #include <condition_variable>
31 #include <cstdint>
32 #include <cstdlib>
33 #include <iostream>
34 #include <limits>
35 #include <locale>
36 #include <memory>
37 #include <mutex>
38 #include <sstream>
39 #include <stdexcept>
40 #include <string>
41 #include <thread>
42 #include <vector>
43 
44 #include "arrow/array.h"
45 #include "arrow/buffer.h"
46 #include "arrow/datum.h"
47 #include "arrow/ipc/json_simple.h"
48 #include "arrow/pretty_print.h"
49 #include "arrow/status.h"
50 #include "arrow/table.h"
51 #include "arrow/type.h"
52 #include "arrow/util/checked_cast.h"
53 #include "arrow/util/future.h"
54 #include "arrow/util/io_util.h"
55 #include "arrow/util/logging.h"
56 #include "arrow/util/windows_compatibility.h"
57 
58 namespace arrow {
59 
60 using internal::checked_cast;
61 using internal::checked_pointer_cast;
62 
AllTypeIds()63 std::vector<Type::type> AllTypeIds() {
64   return {Type::NA,
65           Type::BOOL,
66           Type::INT8,
67           Type::INT16,
68           Type::INT32,
69           Type::INT64,
70           Type::UINT8,
71           Type::UINT16,
72           Type::UINT32,
73           Type::UINT64,
74           Type::HALF_FLOAT,
75           Type::FLOAT,
76           Type::DOUBLE,
77           Type::DECIMAL128,
78           Type::DECIMAL256,
79           Type::DATE32,
80           Type::DATE64,
81           Type::TIME32,
82           Type::TIME64,
83           Type::TIMESTAMP,
84           Type::INTERVAL_DAY_TIME,
85           Type::INTERVAL_MONTHS,
86           Type::DURATION,
87           Type::STRING,
88           Type::BINARY,
89           Type::LARGE_STRING,
90           Type::LARGE_BINARY,
91           Type::FIXED_SIZE_BINARY,
92           Type::STRUCT,
93           Type::LIST,
94           Type::LARGE_LIST,
95           Type::FIXED_SIZE_LIST,
96           Type::MAP,
97           Type::DENSE_UNION,
98           Type::SPARSE_UNION,
99           Type::DICTIONARY,
100           Type::EXTENSION,
101           Type::INTERVAL_MONTH_DAY_NANO};
102 }
103 
104 template <typename T, typename CompareFunctor>
AssertTsSame(const T & expected,const T & actual,CompareFunctor && compare)105 void AssertTsSame(const T& expected, const T& actual, CompareFunctor&& compare) {
106   if (!compare(actual, expected)) {
107     std::stringstream pp_expected;
108     std::stringstream pp_actual;
109     ::arrow::PrettyPrintOptions options(/*indent=*/2);
110     options.window = 50;
111     ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
112     ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
113     FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
114   }
115 }
116 
117 template <typename CompareFunctor>
AssertArraysEqualWith(const Array & expected,const Array & actual,bool verbose,CompareFunctor && compare)118 void AssertArraysEqualWith(const Array& expected, const Array& actual, bool verbose,
119                            CompareFunctor&& compare) {
120   std::stringstream diff;
121   if (!compare(expected, actual, &diff)) {
122     if (expected.data()->null_count != actual.data()->null_count) {
123       diff << "Null counts differ. Expected " << expected.data()->null_count
124            << " but was " << actual.data()->null_count << "\n";
125     }
126     if (verbose) {
127       ::arrow::PrettyPrintOptions options(/*indent=*/2);
128       options.window = 50;
129       diff << "Expected:\n";
130       ARROW_EXPECT_OK(PrettyPrint(expected, options, &diff));
131       diff << "\nActual:\n";
132       ARROW_EXPECT_OK(PrettyPrint(actual, options, &diff));
133     }
134     FAIL() << diff.str();
135   }
136 }
137 
AssertArraysEqual(const Array & expected,const Array & actual,bool verbose,const EqualOptions & options)138 void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose,
139                        const EqualOptions& options) {
140   return AssertArraysEqualWith(
141       expected, actual, verbose,
142       [&](const Array& expected, const Array& actual, std::stringstream* diff) {
143         return expected.Equals(actual, options.diff_sink(diff));
144       });
145 }
146 
AssertArraysApproxEqual(const Array & expected,const Array & actual,bool verbose,const EqualOptions & options)147 void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool verbose,
148                              const EqualOptions& options) {
149   return AssertArraysEqualWith(
150       expected, actual, verbose,
151       [&](const Array& expected, const Array& actual, std::stringstream* diff) {
152         return expected.ApproxEquals(actual, options.diff_sink(diff));
153       });
154 }
155 
AssertScalarsEqual(const Scalar & expected,const Scalar & actual,bool verbose,const EqualOptions & options)156 void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose,
157                         const EqualOptions& options) {
158   if (!expected.Equals(actual, options)) {
159     std::stringstream diff;
160     if (verbose) {
161       diff << "Expected:\n" << expected.ToString();
162       diff << "\nActual:\n" << actual.ToString();
163     }
164     FAIL() << diff.str();
165   }
166 }
167 
AssertScalarsApproxEqual(const Scalar & expected,const Scalar & actual,bool verbose,const EqualOptions & options)168 void AssertScalarsApproxEqual(const Scalar& expected, const Scalar& actual, bool verbose,
169                               const EqualOptions& options) {
170   if (!expected.ApproxEquals(actual, options)) {
171     std::stringstream diff;
172     if (verbose) {
173       diff << "Expected:\n" << expected.ToString();
174       diff << "\nActual:\n" << actual.ToString();
175     }
176     FAIL() << diff.str();
177   }
178 }
179 
AssertBatchesEqual(const RecordBatch & expected,const RecordBatch & actual,bool check_metadata)180 void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual,
181                         bool check_metadata) {
182   AssertTsSame(expected, actual,
183                [&](const RecordBatch& expected, const RecordBatch& actual) {
184                  return expected.Equals(actual, check_metadata);
185                });
186 }
187 
AssertBatchesApproxEqual(const RecordBatch & expected,const RecordBatch & actual)188 void AssertBatchesApproxEqual(const RecordBatch& expected, const RecordBatch& actual) {
189   AssertTsSame(expected, actual,
190                [&](const RecordBatch& expected, const RecordBatch& actual) {
191                  return expected.ApproxEquals(actual);
192                });
193 }
194 
AssertChunkedEqual(const ChunkedArray & expected,const ChunkedArray & actual)195 void AssertChunkedEqual(const ChunkedArray& expected, const ChunkedArray& actual) {
196   ASSERT_EQ(expected.num_chunks(), actual.num_chunks()) << "# chunks unequal";
197   if (!actual.Equals(expected)) {
198     std::stringstream diff;
199     for (int i = 0; i < actual.num_chunks(); ++i) {
200       auto c1 = actual.chunk(i);
201       auto c2 = expected.chunk(i);
202       diff << "# chunk " << i << std::endl;
203       ARROW_IGNORE_EXPR(c1->Equals(c2, EqualOptions().diff_sink(&diff)));
204     }
205     FAIL() << diff.str();
206   }
207 }
208 
AssertChunkedEqual(const ChunkedArray & actual,const ArrayVector & expected)209 void AssertChunkedEqual(const ChunkedArray& actual, const ArrayVector& expected) {
210   AssertChunkedEqual(ChunkedArray(expected, actual.type()), actual);
211 }
212 
AssertChunkedEquivalent(const ChunkedArray & expected,const ChunkedArray & actual)213 void AssertChunkedEquivalent(const ChunkedArray& expected, const ChunkedArray& actual) {
214   // XXX: AssertChunkedEqual in gtest_util.h does not permit the chunk layouts
215   // to be different
216   if (!actual.Equals(expected)) {
217     std::stringstream pp_expected;
218     std::stringstream pp_actual;
219     ::arrow::PrettyPrintOptions options(/*indent=*/2);
220     options.window = 50;
221     ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
222     ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
223     FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
224   }
225 }
226 
AssertChunkedApproxEquivalent(const ChunkedArray & expected,const ChunkedArray & actual,const EqualOptions & equal_options)227 void AssertChunkedApproxEquivalent(const ChunkedArray& expected,
228                                    const ChunkedArray& actual,
229                                    const EqualOptions& equal_options) {
230   if (!actual.ApproxEquals(expected, equal_options)) {
231     std::stringstream pp_expected;
232     std::stringstream pp_actual;
233     ::arrow::PrettyPrintOptions options(/*indent=*/2);
234     options.window = 50;
235     ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
236     ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
237     FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
238   }
239 }
240 
AssertBufferEqual(const Buffer & buffer,const std::vector<uint8_t> & expected)241 void AssertBufferEqual(const Buffer& buffer, const std::vector<uint8_t>& expected) {
242   ASSERT_EQ(static_cast<size_t>(buffer.size()), expected.size())
243       << "Mismatching buffer size";
244   const uint8_t* buffer_data = buffer.data();
245   for (size_t i = 0; i < expected.size(); ++i) {
246     ASSERT_EQ(buffer_data[i], expected[i]);
247   }
248 }
249 
AssertBufferEqual(const Buffer & buffer,const std::string & expected)250 void AssertBufferEqual(const Buffer& buffer, const std::string& expected) {
251   ASSERT_EQ(static_cast<size_t>(buffer.size()), expected.length())
252       << "Mismatching buffer size";
253   const uint8_t* buffer_data = buffer.data();
254   for (size_t i = 0; i < expected.size(); ++i) {
255     ASSERT_EQ(buffer_data[i], expected[i]);
256   }
257 }
258 
AssertBufferEqual(const Buffer & buffer,const Buffer & expected)259 void AssertBufferEqual(const Buffer& buffer, const Buffer& expected) {
260   ASSERT_EQ(buffer.size(), expected.size()) << "Mismatching buffer size";
261   ASSERT_TRUE(buffer.Equals(expected));
262 }
263 
// Generic rendering helper: forwards to T::ToString(show_metadata).
template <typename T>
std::string ToStringWithMetadata(const T& t, bool show_metadata) {
  std::string rendered = t.ToString(show_metadata);
  return rendered;
}
268 
ToStringWithMetadata(const DataType & t,bool show_metadata)269 std::string ToStringWithMetadata(const DataType& t, bool show_metadata) {
270   return t.ToString();
271 }
272 
// Asserts that `left` and `right` compare equal via Equals(other,
// check_metadata) and that their fingerprints also agree.  When
// `check_metadata` is true, each side's metadata fingerprint is appended
// before comparing.  `types_plural` (e.g. "types", "fields") names the kind of
// object in failure messages.
template <typename T>
void AssertFingerprintablesEqual(const T& left, const T& right, bool check_metadata,
                                 const char* types_plural) {
  ASSERT_TRUE(left.Equals(right, check_metadata))
      << types_plural << " '" << ToStringWithMetadata(left, check_metadata) << "' and '"
      << ToStringWithMetadata(right, check_metadata) << "' should have compared equal";
  // Objects that compare equal must also produce identical fingerprints.
  const auto lfp = check_metadata ? left.fingerprint() + left.metadata_fingerprint()
                                  : left.fingerprint();
  const auto rfp = check_metadata ? right.fingerprint() + right.metadata_fingerprint()
                                  : right.fingerprint();
  ASSERT_EQ(lfp, rfp) << "Fingerprints for " << types_plural << " '"
                      << ToStringWithMetadata(left, check_metadata) << "' and '"
                      << ToStringWithMetadata(right, check_metadata)
                      << "' should have compared equal";
}
292 
// shared_ptr overload: null pointers produce a test failure instead of a
// crash; otherwise the pointees are compared.
template <typename T>
void AssertFingerprintablesEqual(const std::shared_ptr<T>& left,
                                 const std::shared_ptr<T>& right, bool check_metadata,
                                 const char* types_plural) {
  ASSERT_NE(left, nullptr);
  ASSERT_NE(right, nullptr);
  const T& left_ref = *left;
  const T& right_ref = *right;
  AssertFingerprintablesEqual(left_ref, right_ref, check_metadata, types_plural);
}
301 
// Asserts that `left` and `right` compare unequal via Equals(other,
// check_metadata) and, when both fingerprints are non-empty, that their
// fingerprints (plus metadata fingerprints if `check_metadata`) differ too.
template <typename T>
void AssertFingerprintablesNotEqual(const T& left, const T& right, bool check_metadata,
                                    const char* types_plural) {
  ASSERT_FALSE(left.Equals(right, check_metadata))
      << types_plural << " '" << ToStringWithMetadata(left, check_metadata) << "' and '"
      << ToStringWithMetadata(right, check_metadata) << "' should have compared unequal";
  auto lfp = left.fingerprint();
  auto rfp = right.fingerprint();
  // An empty fingerprint is treated as "fingerprinting not implemented"
  // (extension types, for example, may not implement it).  Fingerprints then
  // cannot distinguish the objects, so the check is skipped.
  if (lfp.empty() || rfp.empty()) {
    return;
  }
  if (check_metadata) {
    lfp += left.metadata_fingerprint();
    rfp += right.metadata_fingerprint();
  }
  ASSERT_NE(lfp, rfp) << "Fingerprints for " << types_plural << " '"
                      << ToStringWithMetadata(left, check_metadata) << "' and '"
                      << ToStringWithMetadata(right, check_metadata)
                      << "' should have compared unequal";
}
323 
// shared_ptr overload: null pointers produce a test failure instead of a
// crash; otherwise the pointees are compared.
template <typename T>
void AssertFingerprintablesNotEqual(const std::shared_ptr<T>& left,
                                    const std::shared_ptr<T>& right, bool check_metadata,
                                    const char* types_plural) {
  ASSERT_NE(left, nullptr);
  ASSERT_NE(right, nullptr);
  const T& left_ref = *left;
  const T& right_ref = *right;
  AssertFingerprintablesNotEqual(left_ref, right_ref, check_metadata, types_plural);
}
332 
333 #define ASSERT_EQUAL_IMPL(NAME, TYPE, PLURAL)                                            \
334   void Assert##NAME##Equal(const TYPE& left, const TYPE& right, bool check_metadata) {   \
335     AssertFingerprintablesEqual(left, right, check_metadata, PLURAL);                    \
336   }                                                                                      \
337                                                                                          \
338   void Assert##NAME##Equal(const std::shared_ptr<TYPE>& left,                            \
339                            const std::shared_ptr<TYPE>& right, bool check_metadata) {    \
340     AssertFingerprintablesEqual(left, right, check_metadata, PLURAL);                    \
341   }                                                                                      \
342                                                                                          \
343   void Assert##NAME##NotEqual(const TYPE& left, const TYPE& right,                       \
344                               bool check_metadata) {                                     \
345     AssertFingerprintablesNotEqual(left, right, check_metadata, PLURAL);                 \
346   }                                                                                      \
347   void Assert##NAME##NotEqual(const std::shared_ptr<TYPE>& left,                         \
348                               const std::shared_ptr<TYPE>& right, bool check_metadata) { \
349     AssertFingerprintablesNotEqual(left, right, check_metadata, PLURAL);                 \
350   }
351 
352 ASSERT_EQUAL_IMPL(Type, DataType, "types")
353 ASSERT_EQUAL_IMPL(Field, Field, "fields")
354 ASSERT_EQUAL_IMPL(Schema, Schema, "schemas")
355 #undef ASSERT_EQUAL_IMPL
356 
AssertDatumsEqual(const Datum & expected,const Datum & actual,bool verbose)357 void AssertDatumsEqual(const Datum& expected, const Datum& actual, bool verbose) {
358   ASSERT_EQ(expected.kind(), actual.kind())
359       << "expected:" << expected.ToString() << " got:" << actual.ToString();
360 
361   switch (expected.kind()) {
362     case Datum::SCALAR:
363       AssertScalarsEqual(*expected.scalar(), *actual.scalar(), verbose);
364       break;
365     case Datum::ARRAY: {
366       auto expected_array = expected.make_array();
367       auto actual_array = actual.make_array();
368       AssertArraysEqual(*expected_array, *actual_array, verbose);
369     } break;
370     case Datum::CHUNKED_ARRAY:
371       AssertChunkedEquivalent(*expected.chunked_array(), *actual.chunked_array());
372       break;
373     default:
374       // TODO: Implement better print
375       ASSERT_TRUE(actual.Equals(expected));
376       break;
377   }
378 }
379 
AssertDatumsApproxEqual(const Datum & expected,const Datum & actual,bool verbose,const EqualOptions & options)380 void AssertDatumsApproxEqual(const Datum& expected, const Datum& actual, bool verbose,
381                              const EqualOptions& options) {
382   ASSERT_EQ(expected.kind(), actual.kind())
383       << "expected:" << expected.ToString() << " got:" << actual.ToString();
384 
385   switch (expected.kind()) {
386     case Datum::SCALAR:
387       AssertScalarsApproxEqual(*expected.scalar(), *actual.scalar(), verbose, options);
388       break;
389     case Datum::ARRAY: {
390       auto expected_array = expected.make_array();
391       auto actual_array = actual.make_array();
392       AssertArraysApproxEqual(*expected_array, *actual_array, verbose, options);
393       break;
394     }
395     case Datum::CHUNKED_ARRAY: {
396       auto expected_array = expected.chunked_array();
397       auto actual_array = actual.chunked_array();
398       AssertChunkedApproxEquivalent(*expected_array, *actual_array, options);
399       break;
400     }
401     default:
402       // TODO: Implement better print
403       ASSERT_TRUE(actual.Equals(expected));
404       break;
405   }
406 }
407 
ArrayFromJSON(const std::shared_ptr<DataType> & type,util::string_view json)408 std::shared_ptr<Array> ArrayFromJSON(const std::shared_ptr<DataType>& type,
409                                      util::string_view json) {
410   std::shared_ptr<Array> out;
411   ABORT_NOT_OK(ipc::internal::json::ArrayFromJSON(type, json, &out));
412   return out;
413 }
414 
DictArrayFromJSON(const std::shared_ptr<DataType> & type,util::string_view indices_json,util::string_view dictionary_json)415 std::shared_ptr<Array> DictArrayFromJSON(const std::shared_ptr<DataType>& type,
416                                          util::string_view indices_json,
417                                          util::string_view dictionary_json) {
418   std::shared_ptr<Array> out;
419   ABORT_NOT_OK(
420       ipc::internal::json::DictArrayFromJSON(type, indices_json, dictionary_json, &out));
421   return out;
422 }
423 
ChunkedArrayFromJSON(const std::shared_ptr<DataType> & type,const std::vector<std::string> & json)424 std::shared_ptr<ChunkedArray> ChunkedArrayFromJSON(const std::shared_ptr<DataType>& type,
425                                                    const std::vector<std::string>& json) {
426   ArrayVector out_chunks;
427   for (const std::string& chunk_json : json) {
428     out_chunks.push_back(ArrayFromJSON(type, chunk_json));
429   }
430   return std::make_shared<ChunkedArray>(std::move(out_chunks), type);
431 }
432 
RecordBatchFromJSON(const std::shared_ptr<Schema> & schema,util::string_view json)433 std::shared_ptr<RecordBatch> RecordBatchFromJSON(const std::shared_ptr<Schema>& schema,
434                                                  util::string_view json) {
435   // Parse as a StructArray
436   auto struct_type = struct_(schema->fields());
437   std::shared_ptr<Array> struct_array = ArrayFromJSON(struct_type, json);
438 
439   // Convert StructArray to RecordBatch
440   return *RecordBatch::FromStructArray(struct_array);
441 }
442 
ScalarFromJSON(const std::shared_ptr<DataType> & type,util::string_view json)443 std::shared_ptr<Scalar> ScalarFromJSON(const std::shared_ptr<DataType>& type,
444                                        util::string_view json) {
445   std::shared_ptr<Scalar> out;
446   ABORT_NOT_OK(ipc::internal::json::ScalarFromJSON(type, json, &out));
447   return out;
448 }
449 
DictScalarFromJSON(const std::shared_ptr<DataType> & type,util::string_view index_json,util::string_view dictionary_json)450 std::shared_ptr<Scalar> DictScalarFromJSON(const std::shared_ptr<DataType>& type,
451                                            util::string_view index_json,
452                                            util::string_view dictionary_json) {
453   std::shared_ptr<Scalar> out;
454   ABORT_NOT_OK(
455       ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out));
456   return out;
457 }
458 
TableFromJSON(const std::shared_ptr<Schema> & schema,const std::vector<std::string> & json)459 std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>& schema,
460                                      const std::vector<std::string>& json) {
461   std::vector<std::shared_ptr<RecordBatch>> batches;
462   for (const std::string& batch_json : json) {
463     batches.push_back(RecordBatchFromJSON(schema, batch_json));
464   }
465   return *Table::FromRecordBatches(schema, std::move(batches));
466 }
467 
PrintArrayDiff(const ChunkedArray & expected,const ChunkedArray & actual)468 Result<util::optional<std::string>> PrintArrayDiff(const ChunkedArray& expected,
469                                                    const ChunkedArray& actual) {
470   if (actual.Equals(expected)) {
471     return util::nullopt;
472   }
473 
474   std::stringstream ss;
475   if (expected.length() != actual.length()) {
476     ss << "Expected length " << expected.length() << " but was actually "
477        << actual.length();
478     return ss.str();
479   }
480 
481   PrettyPrintOptions options(/*indent=*/2);
482   options.window = 50;
483   RETURN_NOT_OK(internal::ApplyBinaryChunked(
484       actual, expected,
485       [&](const Array& left_piece, const Array& right_piece, int64_t position) {
486         std::stringstream diff;
487         if (!left_piece.Equals(right_piece, EqualOptions().diff_sink(&diff))) {
488           ss << "Unequal at absolute position " << position << "\n" << diff.str();
489           ss << "Expected:\n";
490           ARROW_EXPECT_OK(PrettyPrint(right_piece, options, &ss));
491           ss << "\nActual:\n";
492           ARROW_EXPECT_OK(PrettyPrint(left_piece, options, &ss));
493         }
494         return Status::OK();
495       }));
496   return ss.str();
497 }
498 
AssertTablesEqual(const Table & expected,const Table & actual,bool same_chunk_layout,bool combine_chunks)499 void AssertTablesEqual(const Table& expected, const Table& actual, bool same_chunk_layout,
500                        bool combine_chunks) {
501   ASSERT_EQ(expected.num_columns(), actual.num_columns());
502 
503   if (combine_chunks) {
504     auto pool = default_memory_pool();
505     ASSERT_OK_AND_ASSIGN(auto new_expected, expected.CombineChunks(pool));
506     ASSERT_OK_AND_ASSIGN(auto new_actual, actual.CombineChunks(pool));
507 
508     AssertTablesEqual(*new_expected, *new_actual, false, false);
509     return;
510   }
511 
512   if (same_chunk_layout) {
513     for (int i = 0; i < actual.num_columns(); ++i) {
514       AssertChunkedEqual(*expected.column(i), *actual.column(i));
515     }
516   } else {
517     std::stringstream ss;
518     for (int i = 0; i < actual.num_columns(); ++i) {
519       auto actual_col = actual.column(i);
520       auto expected_col = expected.column(i);
521 
522       ASSERT_OK_AND_ASSIGN(auto diff, PrintArrayDiff(*expected_col, *actual_col));
523       if (diff.has_value()) {
524         FAIL() << *diff;
525       }
526     }
527   }
528 }
529 
530 template <typename CompareFunctor>
CompareBatchWith(const RecordBatch & left,const RecordBatch & right,bool compare_metadata,CompareFunctor && compare)531 void CompareBatchWith(const RecordBatch& left, const RecordBatch& right,
532                       bool compare_metadata, CompareFunctor&& compare) {
533   if (!left.schema()->Equals(*right.schema(), compare_metadata)) {
534     FAIL() << "Left schema: " << left.schema()->ToString(compare_metadata)
535            << "\nRight schema: " << right.schema()->ToString(compare_metadata);
536   }
537   ASSERT_EQ(left.num_columns(), right.num_columns())
538       << left.schema()->ToString() << " result: " << right.schema()->ToString();
539   ASSERT_EQ(left.num_rows(), right.num_rows());
540   for (int i = 0; i < left.num_columns(); ++i) {
541     if (!compare(*left.column(i), *right.column(i))) {
542       std::stringstream ss;
543       ss << "Idx: " << i << " Name: " << left.column_name(i);
544       ss << std::endl << "Left: ";
545       ASSERT_OK(PrettyPrint(*left.column(i), 0, &ss));
546       ss << std::endl << "Right: ";
547       ASSERT_OK(PrettyPrint(*right.column(i), 0, &ss));
548       FAIL() << ss.str();
549     }
550   }
551 }
552 
CompareBatch(const RecordBatch & left,const RecordBatch & right,bool compare_metadata)553 void CompareBatch(const RecordBatch& left, const RecordBatch& right,
554                   bool compare_metadata) {
555   return CompareBatchWith(
556       left, right, compare_metadata,
557       [](const Array& left, const Array& right) { return left.Equals(right); });
558 }
559 
ApproxCompareBatch(const RecordBatch & left,const RecordBatch & right,bool compare_metadata)560 void ApproxCompareBatch(const RecordBatch& left, const RecordBatch& right,
561                         bool compare_metadata) {
562   return CompareBatchWith(
563       left, right, compare_metadata,
564       [](const Array& left, const Array& right) { return left.ApproxEquals(right); });
565 }
566 
TweakValidityBit(const std::shared_ptr<Array> & array,int64_t index,bool validity)567 std::shared_ptr<Array> TweakValidityBit(const std::shared_ptr<Array>& array,
568                                         int64_t index, bool validity) {
569   auto data = array->data()->Copy();
570   if (data->buffers[0] == nullptr) {
571     data->buffers[0] = *AllocateBitmap(data->length);
572     BitUtil::SetBitsTo(data->buffers[0]->mutable_data(), 0, data->length, true);
573   }
574   BitUtil::SetBitTo(data->buffers[0]->mutable_data(), index, validity);
575   data->null_count = kUnknownNullCount;
576   // Need to return a new array, because Array caches the null bitmap pointer
577   return MakeArray(data);
578 }
579 
// Returns true if a std::locale named `locale` can be constructed, and false
// if construction throws std::runtime_error (unknown locale name).
bool LocaleExists(const char* locale) {
  bool available = true;
  try {
    std::locale probe(locale);
    static_cast<void>(probe);
  } catch (const std::runtime_error&) {
    available = false;
  }
  return available;
}
588 
589 class LocaleGuard::Impl {
590  public:
Impl(const char * new_locale)591   explicit Impl(const char* new_locale) : global_locale_(std::locale()) {
592     try {
593       std::locale::global(std::locale(new_locale));
594     } catch (std::runtime_error&) {
595       ARROW_LOG(WARNING) << "Locale unavailable (ignored): '" << new_locale << "'";
596     }
597   }
598 
~Impl()599   ~Impl() { std::locale::global(global_locale_); }
600 
601  protected:
602   std::locale global_locale_;
603 };
604 
LocaleGuard(const char * new_locale)605 LocaleGuard::LocaleGuard(const char* new_locale) : impl_(new Impl(new_locale)) {}
606 
~LocaleGuard()607 LocaleGuard::~LocaleGuard() {}
608 
EnvVarGuard(const std::string & name,const std::string & value)609 EnvVarGuard::EnvVarGuard(const std::string& name, const std::string& value)
610     : name_(name) {
611   auto maybe_value = arrow::internal::GetEnvVar(name);
612   if (maybe_value.ok()) {
613     was_set_ = true;
614     old_value_ = *std::move(maybe_value);
615   } else {
616     was_set_ = false;
617   }
618   ARROW_CHECK_OK(arrow::internal::SetEnvVar(name, value));
619 }
620 
~EnvVarGuard()621 EnvVarGuard::~EnvVarGuard() {
622   if (was_set_) {
623     ARROW_CHECK_OK(arrow::internal::SetEnvVar(name_, old_value_));
624   } else {
625     ARROW_CHECK_OK(arrow::internal::DelEnvVar(name_));
626   }
627 }
628 
629 struct SignalHandlerGuard::Impl {
630   int signum_;
631   internal::SignalHandler old_handler_;
632 
Implarrow::SignalHandlerGuard::Impl633   Impl(int signum, const internal::SignalHandler& handler)
634       : signum_(signum), old_handler_(*internal::SetSignalHandler(signum, handler)) {}
635 
~Implarrow::SignalHandlerGuard::Impl636   ~Impl() { ARROW_EXPECT_OK(internal::SetSignalHandler(signum_, old_handler_)); }
637 };
638 
SignalHandlerGuard(int signum,Callback cb)639 SignalHandlerGuard::SignalHandlerGuard(int signum, Callback cb)
640     : SignalHandlerGuard(signum, internal::SignalHandler(cb)) {}
641 
SignalHandlerGuard(int signum,const internal::SignalHandler & handler)642 SignalHandlerGuard::SignalHandlerGuard(int signum, const internal::SignalHandler& handler)
643     : impl_(new Impl{signum, handler}) {}
644 
645 SignalHandlerGuard::~SignalHandlerGuard() = default;
646 
namespace {

// Sink written by otherwise side-effect-free checks (see TestInitialized).
// Being volatile, the write cannot be optimized away, which keeps the
// preceding computation alive.
volatile int throw_away{0};

}  // namespace
653 
AssertZeroPadded(const Array & array)654 void AssertZeroPadded(const Array& array) {
655   for (const auto& buffer : array.data()->buffers) {
656     if (buffer) {
657       const int64_t padding = buffer->capacity() - buffer->size();
658       if (padding > 0) {
659         std::vector<uint8_t> zeros(padding);
660         ASSERT_EQ(0, memcmp(buffer->data() + buffer->size(), zeros.data(), padding));
661       }
662     }
663   }
664 }
665 
TestInitialized(const Array & array)666 void TestInitialized(const Array& array) { TestInitialized(*array.data()); }
667 
// Reads every byte in `array`'s buffers, and recursively in its child data
// and dictionary, so that a memory checker such as Valgrind flags any byte
// that was never initialized.  Apart from possibly incrementing the
// file-local `throw_away` counter, it has no observable effect.
void TestInitialized(const ArrayData& array) {
  // XOR-fold the bytes [0, size()) of each buffer into one byte, so that every
  // bit of `total` depends on the buffer contents.
  uint8_t total = 0;
  for (const auto& buffer : array.buffers) {
    if (buffer && buffer->capacity() > 0) {
      auto data = buffer->data();
      for (int64_t i = 0; i < buffer->size(); ++i) {
        total ^= data[i];
      }
    }
  }
  // Reduce `total` to its parity bit, touching each of its 8 bits.
  uint8_t total_bit = 0;
  for (uint32_t mask = 1; mask < 256; mask <<= 1) {
    total_bit ^= (total & mask) != 0;
  }
  // This is a dummy condition on all the bits of `total` (which depend on the
  // entire buffer data).  If not all bits are well-defined, Valgrind will
  // error with "Conditional jump or move depends on uninitialised value(s)".
  if (total_bit == 0) {
    ++throw_away;
  }
  // Recurse into nested child data and the dictionary, if any.
  for (const auto& child : array.child_data) {
    TestInitialized(*child);
  }
  if (array.dictionary) {
    TestInitialized(*array.dictionary);
  }
}
695 
// Blocks the calling thread for `seconds` seconds, truncated to whole
// nanoseconds.
void SleepFor(double seconds) {
  const auto duration_ns = static_cast<int64_t>(seconds * 1e9);
  std::this_thread::sleep_for(std::chrono::nanoseconds(duration_ns));
}
700 
701 #ifdef _WIN32
SleepABit()702 void SleepABit() {
703   LARGE_INTEGER freq, start, now;
704   QueryPerformanceFrequency(&freq);
705   // 1 ms
706   auto desired = freq.QuadPart / 1000;
707   if (desired <= 0) {
708     // Fallback to STL sleep if high resolution clock not available, tests may fail,
709     // shouldn't really happen
710     SleepFor(1e-3);
711     return;
712   }
713   QueryPerformanceCounter(&start);
714   while (true) {
715     std::this_thread::yield();
716     QueryPerformanceCounter(&now);
717     auto elapsed = now.QuadPart - start.QuadPart;
718     if (elapsed > desired) {
719       break;
720     }
721   }
722 }
723 #else
724 // std::this_thread::sleep_for should be high enough resolution on non-Windows systems
SleepABit()725 void SleepABit() { SleepFor(1e-3); }
726 #endif
727 
// Polls `predicate` roughly every millisecond until it returns true or until
// `seconds` of wall-clock time (measured on steady_clock) have elapsed.
// `predicate` is always evaluated at least once. A zero, negative or NaN timeout
// checks the predicate exactly once; an infinite timeout waits for the predicate.
void BusyWait(double seconds, std::function<bool()> predicate) {
  const std::chrono::milliseconds poll_period(1);
  const auto start = std::chrono::steady_clock::now();
  // Measure real elapsed time instead of counting iterations: each sleep may
  // overshoot the poll period, which would otherwise stretch the timeout, and
  // comparing in the double domain avoids integer overflow for huge timeouts.
  auto elapsed_seconds = [&start]() {
    return std::chrono::duration<double>(std::chrono::steady_clock::now() - start)
        .count();
  };
  while (!predicate() && elapsed_seconds() < seconds) {
    std::this_thread::sleep_for(poll_period);
  }
}
734 
SleepAsync(double seconds)735 Future<> SleepAsync(double seconds) {
736   auto out = Future<>::Make();
737   std::thread([out, seconds]() mutable {
738     SleepFor(seconds);
739     out.MarkFinished();
740   }).detach();
741   return out;
742 }
743 
SleepABitAsync()744 Future<> SleepABitAsync() {
745   auto out = Future<>::Make();
746   std::thread([out]() mutable {
747     SleepABit();
748     out.MarkFinished();
749   }).detach();
750   return out;
751 }
752 
753 ///////////////////////////////////////////////////////////////////////////
754 // Extension types
755 
ExtensionEquals(const ExtensionType & other) const756 bool UuidType::ExtensionEquals(const ExtensionType& other) const {
757   return (other.extension_name() == this->extension_name());
758 }
759 
MakeArray(std::shared_ptr<ArrayData> data) const760 std::shared_ptr<Array> UuidType::MakeArray(std::shared_ptr<ArrayData> data) const {
761   DCHECK_EQ(data->type->id(), Type::EXTENSION);
762   DCHECK_EQ("uuid", static_cast<const ExtensionType&>(*data->type).extension_name());
763   return std::make_shared<UuidArray>(data);
764 }
765 
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const766 Result<std::shared_ptr<DataType>> UuidType::Deserialize(
767     std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
768   if (serialized != "uuid-serialized") {
769     return Status::Invalid("Type identifier did not match: '", serialized, "'");
770   }
771   if (!storage_type->Equals(*fixed_size_binary(16))) {
772     return Status::Invalid("Invalid storage type for UuidType: ",
773                            storage_type->ToString());
774   }
775   return std::make_shared<UuidType>();
776 }
777 
ExtensionEquals(const ExtensionType & other) const778 bool SmallintType::ExtensionEquals(const ExtensionType& other) const {
779   return (other.extension_name() == this->extension_name());
780 }
781 
MakeArray(std::shared_ptr<ArrayData> data) const782 std::shared_ptr<Array> SmallintType::MakeArray(std::shared_ptr<ArrayData> data) const {
783   DCHECK_EQ(data->type->id(), Type::EXTENSION);
784   DCHECK_EQ("smallint", static_cast<const ExtensionType&>(*data->type).extension_name());
785   return std::make_shared<SmallintArray>(data);
786 }
787 
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const788 Result<std::shared_ptr<DataType>> SmallintType::Deserialize(
789     std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
790   if (serialized != "smallint") {
791     return Status::Invalid("Type identifier did not match: '", serialized, "'");
792   }
793   if (!storage_type->Equals(*int16())) {
794     return Status::Invalid("Invalid storage type for SmallintType: ",
795                            storage_type->ToString());
796   }
797   return std::make_shared<SmallintType>();
798 }
799 
ExtensionEquals(const ExtensionType & other) const800 bool DictExtensionType::ExtensionEquals(const ExtensionType& other) const {
801   return (other.extension_name() == this->extension_name());
802 }
803 
MakeArray(std::shared_ptr<ArrayData> data) const804 std::shared_ptr<Array> DictExtensionType::MakeArray(
805     std::shared_ptr<ArrayData> data) const {
806   DCHECK_EQ(data->type->id(), Type::EXTENSION);
807   DCHECK(ExtensionEquals(checked_cast<const ExtensionType&>(*data->type)));
808   // No need for a specific ExtensionArray derived class
809   return std::make_shared<ExtensionArray>(data);
810 }
811 
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const812 Result<std::shared_ptr<DataType>> DictExtensionType::Deserialize(
813     std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
814   if (serialized != "dict-extension-serialized") {
815     return Status::Invalid("Type identifier did not match: '", serialized, "'");
816   }
817   if (!storage_type->Equals(*storage_type_)) {
818     return Status::Invalid("Invalid storage type for DictExtensionType: ",
819                            storage_type->ToString());
820   }
821   return std::make_shared<DictExtensionType>();
822 }
823 
ExtensionEquals(const ExtensionType & other) const824 bool Complex128Type::ExtensionEquals(const ExtensionType& other) const {
825   return (other.extension_name() == this->extension_name());
826 }
827 
MakeArray(std::shared_ptr<ArrayData> data) const828 std::shared_ptr<Array> Complex128Type::MakeArray(std::shared_ptr<ArrayData> data) const {
829   DCHECK_EQ(data->type->id(), Type::EXTENSION);
830   DCHECK(ExtensionEquals(checked_cast<const ExtensionType&>(*data->type)));
831   return std::make_shared<Complex128Array>(data);
832 }
833 
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const834 Result<std::shared_ptr<DataType>> Complex128Type::Deserialize(
835     std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
836   if (serialized != "complex128-serialized") {
837     return Status::Invalid("Type identifier did not match: '", serialized, "'");
838   }
839   if (!storage_type->Equals(*storage_type_)) {
840     return Status::Invalid("Invalid storage type for Complex128Type: ",
841                            storage_type->ToString());
842   }
843   return std::make_shared<Complex128Type>();
844 }
845 
uuid()846 std::shared_ptr<DataType> uuid() { return std::make_shared<UuidType>(); }
847 
smallint()848 std::shared_ptr<DataType> smallint() { return std::make_shared<SmallintType>(); }
849 
dict_extension_type()850 std::shared_ptr<DataType> dict_extension_type() {
851   return std::make_shared<DictExtensionType>();
852 }
853 
complex128()854 std::shared_ptr<DataType> complex128() { return std::make_shared<Complex128Type>(); }
855 
MakeComplex128(const std::shared_ptr<Array> & real,const std::shared_ptr<Array> & imag)856 std::shared_ptr<Array> MakeComplex128(const std::shared_ptr<Array>& real,
857                                       const std::shared_ptr<Array>& imag) {
858   auto type = complex128();
859   std::shared_ptr<Array> storage(
860       new StructArray(checked_cast<const ExtensionType&>(*type).storage_type(),
861                       real->length(), {real, imag}));
862   return ExtensionType::WrapArray(type, storage);
863 }
864 
ExampleUuid()865 std::shared_ptr<Array> ExampleUuid() {
866   auto arr = ArrayFromJSON(
867       fixed_size_binary(16),
868       "[null, \"abcdefghijklmno0\", \"abcdefghijklmno1\", \"abcdefghijklmno2\"]");
869   return ExtensionType::WrapArray(uuid(), arr);
870 }
871 
ExampleSmallint()872 std::shared_ptr<Array> ExampleSmallint() {
873   auto arr = ArrayFromJSON(int16(), "[-32768, null, 1, 2, 3, 4, 32767]");
874   return ExtensionType::WrapArray(smallint(), arr);
875 }
876 
ExampleDictExtension()877 std::shared_ptr<Array> ExampleDictExtension() {
878   auto arr = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, null, 1]",
879                                R"(["foo", "bar"])");
880   return ExtensionType::WrapArray(dict_extension_type(), arr);
881 }
882 
ExampleComplex128()883 std::shared_ptr<Array> ExampleComplex128() {
884   auto arr = ArrayFromJSON(struct_({field("", float64()), field("", float64())}),
885                            "[[1.0, -2.5], null, [3.0, -4.5]]");
886   return ExtensionType::WrapArray(complex128(), arr);
887 }
888 
ExtensionTypeGuard(const std::shared_ptr<DataType> & type)889 ExtensionTypeGuard::ExtensionTypeGuard(const std::shared_ptr<DataType>& type)
890     : ExtensionTypeGuard(DataTypeVector{type}) {}
891 
ExtensionTypeGuard(const DataTypeVector & types)892 ExtensionTypeGuard::ExtensionTypeGuard(const DataTypeVector& types) {
893   for (const auto& type : types) {
894     ARROW_CHECK_EQ(type->id(), Type::EXTENSION);
895     auto ext_type = checked_pointer_cast<ExtensionType>(type);
896 
897     ARROW_CHECK_OK(RegisterExtensionType(ext_type));
898     extension_names_.push_back(ext_type->extension_name());
899     DCHECK(!extension_names_.back().empty());
900   }
901 }
902 
~ExtensionTypeGuard()903 ExtensionTypeGuard::~ExtensionTypeGuard() {
904   for (const auto& name : extension_names_) {
905     ARROW_CHECK_OK(UnregisterExtensionType(name));
906   }
907 }
908 
909 class GatingTask::Impl : public std::enable_shared_from_this<GatingTask::Impl> {
910  public:
Impl(double timeout_seconds)911   explicit Impl(double timeout_seconds)
912       : timeout_seconds_(timeout_seconds), status_(), unlocked_(false) {
913     unlocked_future_ = Future<>::Make();
914   }
915 
~Impl()916   ~Impl() {
917     if (num_running_ != num_launched_) {
918       ADD_FAILURE()
919           << "A GatingTask instance was destroyed but some underlying tasks did not "
920              "start running"
921           << std::endl;
922     } else if (num_finished_ != num_launched_) {
923       ADD_FAILURE()
924           << "A GatingTask instance was destroyed but some underlying tasks did not "
925              "finish running"
926           << std::endl;
927     }
928   }
929 
Task()930   std::function<void()> Task() {
931     num_launched_++;
932     auto self = shared_from_this();
933     return [self] { self->RunTask(); };
934   }
935 
AsyncTask()936   Future<> AsyncTask() {
937     num_launched_++;
938     num_running_++;
939     /// TODO(ARROW-13004) Could maybe implement this check with future chains
940     /// if we check to see if the future has been "consumed" or not
941     num_finished_++;
942     return unlocked_future_;
943   }
944 
RunTask()945   void RunTask() {
946     std::unique_lock<std::mutex> lk(mx_);
947     num_running_++;
948     running_cv_.notify_all();
949     if (!unlocked_cv_.wait_for(
950             lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
951             [this] { return unlocked_; })) {
952       status_ &= Status::Invalid("Timed out (" + std::to_string(timeout_seconds_) + "," +
953                                  std::to_string(unlocked_) +
954                                  " seconds) waiting for the gating task to be unlocked");
955     }
956     num_finished_++;
957   }
958 
WaitForRunning(int count)959   Status WaitForRunning(int count) {
960     std::unique_lock<std::mutex> lk(mx_);
961     if (running_cv_.wait_for(
962             lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
963             [this, count] { return num_running_ >= count; })) {
964       return Status::OK();
965     }
966     return Status::Invalid("Timed out waiting for tasks to launch");
967   }
968 
Unlock()969   Status Unlock() {
970     std::lock_guard<std::mutex> lk(mx_);
971     unlocked_ = true;
972     unlocked_cv_.notify_all();
973     unlocked_future_.MarkFinished();
974     return status_;
975   }
976 
977  private:
978   double timeout_seconds_;
979   Status status_;
980   bool unlocked_;
981   std::atomic<int> num_launched_{0};
982   int num_running_ = 0;
983   int num_finished_ = 0;
984   std::mutex mx_;
985   std::condition_variable running_cv_;
986   std::condition_variable unlocked_cv_;
987   Future<> unlocked_future_;
988 };
989 
GatingTask(double timeout_seconds)990 GatingTask::GatingTask(double timeout_seconds) : impl_(new Impl(timeout_seconds)) {}
991 
~GatingTask()992 GatingTask::~GatingTask() {}
993 
Task()994 std::function<void()> GatingTask::Task() { return impl_->Task(); }
995 
AsyncTask()996 Future<> GatingTask::AsyncTask() { return impl_->AsyncTask(); }
997 
Unlock()998 Status GatingTask::Unlock() { return impl_->Unlock(); }
999 
WaitForRunning(int count)1000 Status GatingTask::WaitForRunning(int count) { return impl_->WaitForRunning(count); }
1001 
Make(double timeout_seconds)1002 std::shared_ptr<GatingTask> GatingTask::Make(double timeout_seconds) {
1003   return std::make_shared<GatingTask>(timeout_seconds);
1004 }
1005 
1006 }  // namespace arrow
1007