1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/testing/gtest_util.h"
19
20 #include "arrow/testing/extension_type.h"
21
22 #ifndef _WIN32
23 #include <sys/stat.h> // IWYU pragma: keep
24 #include <sys/wait.h> // IWYU pragma: keep
25 #include <unistd.h> // IWYU pragma: keep
26 #endif
27
28 #include <algorithm>
29 #include <chrono>
30 #include <condition_variable>
31 #include <cstdint>
32 #include <cstdlib>
33 #include <iostream>
34 #include <limits>
35 #include <locale>
36 #include <memory>
37 #include <mutex>
38 #include <sstream>
39 #include <stdexcept>
40 #include <string>
41 #include <thread>
42 #include <vector>
43
44 #include "arrow/array.h"
45 #include "arrow/buffer.h"
46 #include "arrow/datum.h"
47 #include "arrow/ipc/json_simple.h"
48 #include "arrow/pretty_print.h"
49 #include "arrow/status.h"
50 #include "arrow/table.h"
51 #include "arrow/type.h"
52 #include "arrow/util/checked_cast.h"
53 #include "arrow/util/future.h"
54 #include "arrow/util/io_util.h"
55 #include "arrow/util/logging.h"
56 #include "arrow/util/windows_compatibility.h"
57
58 namespace arrow {
59
60 using internal::checked_cast;
61 using internal::checked_pointer_cast;
62
AllTypeIds()63 std::vector<Type::type> AllTypeIds() {
64 return {Type::NA,
65 Type::BOOL,
66 Type::INT8,
67 Type::INT16,
68 Type::INT32,
69 Type::INT64,
70 Type::UINT8,
71 Type::UINT16,
72 Type::UINT32,
73 Type::UINT64,
74 Type::HALF_FLOAT,
75 Type::FLOAT,
76 Type::DOUBLE,
77 Type::DECIMAL128,
78 Type::DECIMAL256,
79 Type::DATE32,
80 Type::DATE64,
81 Type::TIME32,
82 Type::TIME64,
83 Type::TIMESTAMP,
84 Type::INTERVAL_DAY_TIME,
85 Type::INTERVAL_MONTHS,
86 Type::DURATION,
87 Type::STRING,
88 Type::BINARY,
89 Type::LARGE_STRING,
90 Type::LARGE_BINARY,
91 Type::FIXED_SIZE_BINARY,
92 Type::STRUCT,
93 Type::LIST,
94 Type::LARGE_LIST,
95 Type::FIXED_SIZE_LIST,
96 Type::MAP,
97 Type::DENSE_UNION,
98 Type::SPARSE_UNION,
99 Type::DICTIONARY,
100 Type::EXTENSION,
101 Type::INTERVAL_MONTH_DAY_NANO};
102 }
103
104 template <typename T, typename CompareFunctor>
AssertTsSame(const T & expected,const T & actual,CompareFunctor && compare)105 void AssertTsSame(const T& expected, const T& actual, CompareFunctor&& compare) {
106 if (!compare(actual, expected)) {
107 std::stringstream pp_expected;
108 std::stringstream pp_actual;
109 ::arrow::PrettyPrintOptions options(/*indent=*/2);
110 options.window = 50;
111 ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
112 ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
113 FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
114 }
115 }
116
117 template <typename CompareFunctor>
AssertArraysEqualWith(const Array & expected,const Array & actual,bool verbose,CompareFunctor && compare)118 void AssertArraysEqualWith(const Array& expected, const Array& actual, bool verbose,
119 CompareFunctor&& compare) {
120 std::stringstream diff;
121 if (!compare(expected, actual, &diff)) {
122 if (expected.data()->null_count != actual.data()->null_count) {
123 diff << "Null counts differ. Expected " << expected.data()->null_count
124 << " but was " << actual.data()->null_count << "\n";
125 }
126 if (verbose) {
127 ::arrow::PrettyPrintOptions options(/*indent=*/2);
128 options.window = 50;
129 diff << "Expected:\n";
130 ARROW_EXPECT_OK(PrettyPrint(expected, options, &diff));
131 diff << "\nActual:\n";
132 ARROW_EXPECT_OK(PrettyPrint(actual, options, &diff));
133 }
134 FAIL() << diff.str();
135 }
136 }
137
AssertArraysEqual(const Array & expected,const Array & actual,bool verbose,const EqualOptions & options)138 void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose,
139 const EqualOptions& options) {
140 return AssertArraysEqualWith(
141 expected, actual, verbose,
142 [&](const Array& expected, const Array& actual, std::stringstream* diff) {
143 return expected.Equals(actual, options.diff_sink(diff));
144 });
145 }
146
AssertArraysApproxEqual(const Array & expected,const Array & actual,bool verbose,const EqualOptions & options)147 void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool verbose,
148 const EqualOptions& options) {
149 return AssertArraysEqualWith(
150 expected, actual, verbose,
151 [&](const Array& expected, const Array& actual, std::stringstream* diff) {
152 return expected.ApproxEquals(actual, options.diff_sink(diff));
153 });
154 }
155
AssertScalarsEqual(const Scalar & expected,const Scalar & actual,bool verbose,const EqualOptions & options)156 void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose,
157 const EqualOptions& options) {
158 if (!expected.Equals(actual, options)) {
159 std::stringstream diff;
160 if (verbose) {
161 diff << "Expected:\n" << expected.ToString();
162 diff << "\nActual:\n" << actual.ToString();
163 }
164 FAIL() << diff.str();
165 }
166 }
167
AssertScalarsApproxEqual(const Scalar & expected,const Scalar & actual,bool verbose,const EqualOptions & options)168 void AssertScalarsApproxEqual(const Scalar& expected, const Scalar& actual, bool verbose,
169 const EqualOptions& options) {
170 if (!expected.ApproxEquals(actual, options)) {
171 std::stringstream diff;
172 if (verbose) {
173 diff << "Expected:\n" << expected.ToString();
174 diff << "\nActual:\n" << actual.ToString();
175 }
176 FAIL() << diff.str();
177 }
178 }
179
AssertBatchesEqual(const RecordBatch & expected,const RecordBatch & actual,bool check_metadata)180 void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual,
181 bool check_metadata) {
182 AssertTsSame(expected, actual,
183 [&](const RecordBatch& expected, const RecordBatch& actual) {
184 return expected.Equals(actual, check_metadata);
185 });
186 }
187
AssertBatchesApproxEqual(const RecordBatch & expected,const RecordBatch & actual)188 void AssertBatchesApproxEqual(const RecordBatch& expected, const RecordBatch& actual) {
189 AssertTsSame(expected, actual,
190 [&](const RecordBatch& expected, const RecordBatch& actual) {
191 return expected.ApproxEquals(actual);
192 });
193 }
194
AssertChunkedEqual(const ChunkedArray & expected,const ChunkedArray & actual)195 void AssertChunkedEqual(const ChunkedArray& expected, const ChunkedArray& actual) {
196 ASSERT_EQ(expected.num_chunks(), actual.num_chunks()) << "# chunks unequal";
197 if (!actual.Equals(expected)) {
198 std::stringstream diff;
199 for (int i = 0; i < actual.num_chunks(); ++i) {
200 auto c1 = actual.chunk(i);
201 auto c2 = expected.chunk(i);
202 diff << "# chunk " << i << std::endl;
203 ARROW_IGNORE_EXPR(c1->Equals(c2, EqualOptions().diff_sink(&diff)));
204 }
205 FAIL() << diff.str();
206 }
207 }
208
AssertChunkedEqual(const ChunkedArray & actual,const ArrayVector & expected)209 void AssertChunkedEqual(const ChunkedArray& actual, const ArrayVector& expected) {
210 AssertChunkedEqual(ChunkedArray(expected, actual.type()), actual);
211 }
212
AssertChunkedEquivalent(const ChunkedArray & expected,const ChunkedArray & actual)213 void AssertChunkedEquivalent(const ChunkedArray& expected, const ChunkedArray& actual) {
214 // XXX: AssertChunkedEqual in gtest_util.h does not permit the chunk layouts
215 // to be different
216 if (!actual.Equals(expected)) {
217 std::stringstream pp_expected;
218 std::stringstream pp_actual;
219 ::arrow::PrettyPrintOptions options(/*indent=*/2);
220 options.window = 50;
221 ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
222 ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
223 FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
224 }
225 }
226
AssertChunkedApproxEquivalent(const ChunkedArray & expected,const ChunkedArray & actual,const EqualOptions & equal_options)227 void AssertChunkedApproxEquivalent(const ChunkedArray& expected,
228 const ChunkedArray& actual,
229 const EqualOptions& equal_options) {
230 if (!actual.ApproxEquals(expected, equal_options)) {
231 std::stringstream pp_expected;
232 std::stringstream pp_actual;
233 ::arrow::PrettyPrintOptions options(/*indent=*/2);
234 options.window = 50;
235 ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
236 ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
237 FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
238 }
239 }
240
AssertBufferEqual(const Buffer & buffer,const std::vector<uint8_t> & expected)241 void AssertBufferEqual(const Buffer& buffer, const std::vector<uint8_t>& expected) {
242 ASSERT_EQ(static_cast<size_t>(buffer.size()), expected.size())
243 << "Mismatching buffer size";
244 const uint8_t* buffer_data = buffer.data();
245 for (size_t i = 0; i < expected.size(); ++i) {
246 ASSERT_EQ(buffer_data[i], expected[i]);
247 }
248 }
249
AssertBufferEqual(const Buffer & buffer,const std::string & expected)250 void AssertBufferEqual(const Buffer& buffer, const std::string& expected) {
251 ASSERT_EQ(static_cast<size_t>(buffer.size()), expected.length())
252 << "Mismatching buffer size";
253 const uint8_t* buffer_data = buffer.data();
254 for (size_t i = 0; i < expected.size(); ++i) {
255 ASSERT_EQ(buffer_data[i], expected[i]);
256 }
257 }
258
AssertBufferEqual(const Buffer & buffer,const Buffer & expected)259 void AssertBufferEqual(const Buffer& buffer, const Buffer& expected) {
260 ASSERT_EQ(buffer.size(), expected.size()) << "Mismatching buffer size";
261 ASSERT_TRUE(buffer.Equals(expected));
262 }
263
// Generic case: forward `show_metadata` to the object's own ToString(bool).
template <typename T>
std::string ToStringWithMetadata(const T& t, bool show_metadata) {
  std::string text = t.ToString(show_metadata);
  return text;
}
268
ToStringWithMetadata(const DataType & t,bool show_metadata)269 std::string ToStringWithMetadata(const DataType& t, bool show_metadata) {
270 return t.ToString();
271 }
272
// Assert that `left` and `right` compare equal and that their fingerprints
// agree. When `check_metadata` is set, the metadata fingerprints are appended
// before the fingerprints are compared. `types_plural` names the kind of
// object in failure messages.
template <typename T>
void AssertFingerprintablesEqual(const T& left, const T& right, bool check_metadata,
                                 const char* types_plural) {
  ASSERT_TRUE(left.Equals(right, check_metadata))
      << types_plural << " '" << ToStringWithMetadata(left, check_metadata) << "' and '"
      << ToStringWithMetadata(right, check_metadata) << "' should have compared equal";

  // Note: all types tested in this file should implement fingerprinting,
  // except extension types.
  auto left_fingerprint = left.fingerprint();
  auto right_fingerprint = right.fingerprint();
  if (check_metadata) {
    left_fingerprint += left.metadata_fingerprint();
    right_fingerprint += right.metadata_fingerprint();
  }
  ASSERT_EQ(left_fingerprint, right_fingerprint)
      << "Fingerprints for " << types_plural << " '"
      << ToStringWithMetadata(left, check_metadata) << "' and '"
      << ToStringWithMetadata(right, check_metadata) << "' should have compared equal";
}
292
// shared_ptr overload: asserts both pointers are non-null, then compares the
// pointed-to objects with the by-reference overload.
template <typename T>
void AssertFingerprintablesEqual(const std::shared_ptr<T>& left,
                                 const std::shared_ptr<T>& right, bool check_metadata,
                                 const char* types_plural) {
  ASSERT_NE(left, nullptr);
  ASSERT_NE(right, nullptr);
  const T& left_value = *left;
  const T& right_value = *right;
  AssertFingerprintablesEqual(left_value, right_value, check_metadata, types_plural);
}
301
// Assert that `left` and `right` compare unequal and, when both provide a
// non-empty fingerprint, that the fingerprints differ too. When
// `check_metadata` is set, the metadata fingerprints are appended before the
// fingerprints are compared. `types_plural` names the kind of object in
// failure messages.
template <typename T>
void AssertFingerprintablesNotEqual(const T& left, const T& right, bool check_metadata,
                                    const char* types_plural) {
  ASSERT_FALSE(left.Equals(right, check_metadata))
      << types_plural << " '" << ToStringWithMetadata(left, check_metadata) << "' and '"
      << ToStringWithMetadata(right, check_metadata) << "' should have compared unequal";

  // Note: all types tested in this file should implement fingerprinting,
  // except extension types.
  auto left_fingerprint = left.fingerprint();
  auto right_fingerprint = right.fingerprint();
  if (left_fingerprint == "" || right_fingerprint == "") {
    // At least one side has no fingerprint: nothing further to compare.
    return;
  }
  if (check_metadata) {
    left_fingerprint += left.metadata_fingerprint();
    right_fingerprint += right.metadata_fingerprint();
  }
  ASSERT_NE(left_fingerprint, right_fingerprint)
      << "Fingerprints for " << types_plural << " '"
      << ToStringWithMetadata(left, check_metadata) << "' and '"
      << ToStringWithMetadata(right, check_metadata) << "' should have compared unequal";
}
323
// shared_ptr overload: asserts both pointers are non-null, then compares the
// pointed-to objects with the by-reference overload.
template <typename T>
void AssertFingerprintablesNotEqual(const std::shared_ptr<T>& left,
                                    const std::shared_ptr<T>& right, bool check_metadata,
                                    const char* types_plural) {
  ASSERT_NE(left, nullptr);
  ASSERT_NE(right, nullptr);
  const T& left_value = *left;
  const T& right_value = *right;
  AssertFingerprintablesNotEqual(left_value, right_value, check_metadata, types_plural);
}
332
333 #define ASSERT_EQUAL_IMPL(NAME, TYPE, PLURAL) \
334 void Assert##NAME##Equal(const TYPE& left, const TYPE& right, bool check_metadata) { \
335 AssertFingerprintablesEqual(left, right, check_metadata, PLURAL); \
336 } \
337 \
338 void Assert##NAME##Equal(const std::shared_ptr<TYPE>& left, \
339 const std::shared_ptr<TYPE>& right, bool check_metadata) { \
340 AssertFingerprintablesEqual(left, right, check_metadata, PLURAL); \
341 } \
342 \
343 void Assert##NAME##NotEqual(const TYPE& left, const TYPE& right, \
344 bool check_metadata) { \
345 AssertFingerprintablesNotEqual(left, right, check_metadata, PLURAL); \
346 } \
347 void Assert##NAME##NotEqual(const std::shared_ptr<TYPE>& left, \
348 const std::shared_ptr<TYPE>& right, bool check_metadata) { \
349 AssertFingerprintablesNotEqual(left, right, check_metadata, PLURAL); \
350 }
351
352 ASSERT_EQUAL_IMPL(Type, DataType, "types")
353 ASSERT_EQUAL_IMPL(Field, Field, "fields")
354 ASSERT_EQUAL_IMPL(Schema, Schema, "schemas")
355 #undef ASSERT_EQUAL_IMPL
356
AssertDatumsEqual(const Datum & expected,const Datum & actual,bool verbose)357 void AssertDatumsEqual(const Datum& expected, const Datum& actual, bool verbose) {
358 ASSERT_EQ(expected.kind(), actual.kind())
359 << "expected:" << expected.ToString() << " got:" << actual.ToString();
360
361 switch (expected.kind()) {
362 case Datum::SCALAR:
363 AssertScalarsEqual(*expected.scalar(), *actual.scalar(), verbose);
364 break;
365 case Datum::ARRAY: {
366 auto expected_array = expected.make_array();
367 auto actual_array = actual.make_array();
368 AssertArraysEqual(*expected_array, *actual_array, verbose);
369 } break;
370 case Datum::CHUNKED_ARRAY:
371 AssertChunkedEquivalent(*expected.chunked_array(), *actual.chunked_array());
372 break;
373 default:
374 // TODO: Implement better print
375 ASSERT_TRUE(actual.Equals(expected));
376 break;
377 }
378 }
379
AssertDatumsApproxEqual(const Datum & expected,const Datum & actual,bool verbose,const EqualOptions & options)380 void AssertDatumsApproxEqual(const Datum& expected, const Datum& actual, bool verbose,
381 const EqualOptions& options) {
382 ASSERT_EQ(expected.kind(), actual.kind())
383 << "expected:" << expected.ToString() << " got:" << actual.ToString();
384
385 switch (expected.kind()) {
386 case Datum::SCALAR:
387 AssertScalarsApproxEqual(*expected.scalar(), *actual.scalar(), verbose, options);
388 break;
389 case Datum::ARRAY: {
390 auto expected_array = expected.make_array();
391 auto actual_array = actual.make_array();
392 AssertArraysApproxEqual(*expected_array, *actual_array, verbose, options);
393 break;
394 }
395 case Datum::CHUNKED_ARRAY: {
396 auto expected_array = expected.chunked_array();
397 auto actual_array = actual.chunked_array();
398 AssertChunkedApproxEquivalent(*expected_array, *actual_array, options);
399 break;
400 }
401 default:
402 // TODO: Implement better print
403 ASSERT_TRUE(actual.Equals(expected));
404 break;
405 }
406 }
407
ArrayFromJSON(const std::shared_ptr<DataType> & type,util::string_view json)408 std::shared_ptr<Array> ArrayFromJSON(const std::shared_ptr<DataType>& type,
409 util::string_view json) {
410 std::shared_ptr<Array> out;
411 ABORT_NOT_OK(ipc::internal::json::ArrayFromJSON(type, json, &out));
412 return out;
413 }
414
DictArrayFromJSON(const std::shared_ptr<DataType> & type,util::string_view indices_json,util::string_view dictionary_json)415 std::shared_ptr<Array> DictArrayFromJSON(const std::shared_ptr<DataType>& type,
416 util::string_view indices_json,
417 util::string_view dictionary_json) {
418 std::shared_ptr<Array> out;
419 ABORT_NOT_OK(
420 ipc::internal::json::DictArrayFromJSON(type, indices_json, dictionary_json, &out));
421 return out;
422 }
423
ChunkedArrayFromJSON(const std::shared_ptr<DataType> & type,const std::vector<std::string> & json)424 std::shared_ptr<ChunkedArray> ChunkedArrayFromJSON(const std::shared_ptr<DataType>& type,
425 const std::vector<std::string>& json) {
426 ArrayVector out_chunks;
427 for (const std::string& chunk_json : json) {
428 out_chunks.push_back(ArrayFromJSON(type, chunk_json));
429 }
430 return std::make_shared<ChunkedArray>(std::move(out_chunks), type);
431 }
432
RecordBatchFromJSON(const std::shared_ptr<Schema> & schema,util::string_view json)433 std::shared_ptr<RecordBatch> RecordBatchFromJSON(const std::shared_ptr<Schema>& schema,
434 util::string_view json) {
435 // Parse as a StructArray
436 auto struct_type = struct_(schema->fields());
437 std::shared_ptr<Array> struct_array = ArrayFromJSON(struct_type, json);
438
439 // Convert StructArray to RecordBatch
440 return *RecordBatch::FromStructArray(struct_array);
441 }
442
ScalarFromJSON(const std::shared_ptr<DataType> & type,util::string_view json)443 std::shared_ptr<Scalar> ScalarFromJSON(const std::shared_ptr<DataType>& type,
444 util::string_view json) {
445 std::shared_ptr<Scalar> out;
446 ABORT_NOT_OK(ipc::internal::json::ScalarFromJSON(type, json, &out));
447 return out;
448 }
449
DictScalarFromJSON(const std::shared_ptr<DataType> & type,util::string_view index_json,util::string_view dictionary_json)450 std::shared_ptr<Scalar> DictScalarFromJSON(const std::shared_ptr<DataType>& type,
451 util::string_view index_json,
452 util::string_view dictionary_json) {
453 std::shared_ptr<Scalar> out;
454 ABORT_NOT_OK(
455 ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out));
456 return out;
457 }
458
TableFromJSON(const std::shared_ptr<Schema> & schema,const std::vector<std::string> & json)459 std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>& schema,
460 const std::vector<std::string>& json) {
461 std::vector<std::shared_ptr<RecordBatch>> batches;
462 for (const std::string& batch_json : json) {
463 batches.push_back(RecordBatchFromJSON(schema, batch_json));
464 }
465 return *Table::FromRecordBatches(schema, std::move(batches));
466 }
467
PrintArrayDiff(const ChunkedArray & expected,const ChunkedArray & actual)468 Result<util::optional<std::string>> PrintArrayDiff(const ChunkedArray& expected,
469 const ChunkedArray& actual) {
470 if (actual.Equals(expected)) {
471 return util::nullopt;
472 }
473
474 std::stringstream ss;
475 if (expected.length() != actual.length()) {
476 ss << "Expected length " << expected.length() << " but was actually "
477 << actual.length();
478 return ss.str();
479 }
480
481 PrettyPrintOptions options(/*indent=*/2);
482 options.window = 50;
483 RETURN_NOT_OK(internal::ApplyBinaryChunked(
484 actual, expected,
485 [&](const Array& left_piece, const Array& right_piece, int64_t position) {
486 std::stringstream diff;
487 if (!left_piece.Equals(right_piece, EqualOptions().diff_sink(&diff))) {
488 ss << "Unequal at absolute position " << position << "\n" << diff.str();
489 ss << "Expected:\n";
490 ARROW_EXPECT_OK(PrettyPrint(right_piece, options, &ss));
491 ss << "\nActual:\n";
492 ARROW_EXPECT_OK(PrettyPrint(left_piece, options, &ss));
493 }
494 return Status::OK();
495 }));
496 return ss.str();
497 }
498
AssertTablesEqual(const Table & expected,const Table & actual,bool same_chunk_layout,bool combine_chunks)499 void AssertTablesEqual(const Table& expected, const Table& actual, bool same_chunk_layout,
500 bool combine_chunks) {
501 ASSERT_EQ(expected.num_columns(), actual.num_columns());
502
503 if (combine_chunks) {
504 auto pool = default_memory_pool();
505 ASSERT_OK_AND_ASSIGN(auto new_expected, expected.CombineChunks(pool));
506 ASSERT_OK_AND_ASSIGN(auto new_actual, actual.CombineChunks(pool));
507
508 AssertTablesEqual(*new_expected, *new_actual, false, false);
509 return;
510 }
511
512 if (same_chunk_layout) {
513 for (int i = 0; i < actual.num_columns(); ++i) {
514 AssertChunkedEqual(*expected.column(i), *actual.column(i));
515 }
516 } else {
517 std::stringstream ss;
518 for (int i = 0; i < actual.num_columns(); ++i) {
519 auto actual_col = actual.column(i);
520 auto expected_col = expected.column(i);
521
522 ASSERT_OK_AND_ASSIGN(auto diff, PrintArrayDiff(*expected_col, *actual_col));
523 if (diff.has_value()) {
524 FAIL() << *diff;
525 }
526 }
527 }
528 }
529
530 template <typename CompareFunctor>
CompareBatchWith(const RecordBatch & left,const RecordBatch & right,bool compare_metadata,CompareFunctor && compare)531 void CompareBatchWith(const RecordBatch& left, const RecordBatch& right,
532 bool compare_metadata, CompareFunctor&& compare) {
533 if (!left.schema()->Equals(*right.schema(), compare_metadata)) {
534 FAIL() << "Left schema: " << left.schema()->ToString(compare_metadata)
535 << "\nRight schema: " << right.schema()->ToString(compare_metadata);
536 }
537 ASSERT_EQ(left.num_columns(), right.num_columns())
538 << left.schema()->ToString() << " result: " << right.schema()->ToString();
539 ASSERT_EQ(left.num_rows(), right.num_rows());
540 for (int i = 0; i < left.num_columns(); ++i) {
541 if (!compare(*left.column(i), *right.column(i))) {
542 std::stringstream ss;
543 ss << "Idx: " << i << " Name: " << left.column_name(i);
544 ss << std::endl << "Left: ";
545 ASSERT_OK(PrettyPrint(*left.column(i), 0, &ss));
546 ss << std::endl << "Right: ";
547 ASSERT_OK(PrettyPrint(*right.column(i), 0, &ss));
548 FAIL() << ss.str();
549 }
550 }
551 }
552
CompareBatch(const RecordBatch & left,const RecordBatch & right,bool compare_metadata)553 void CompareBatch(const RecordBatch& left, const RecordBatch& right,
554 bool compare_metadata) {
555 return CompareBatchWith(
556 left, right, compare_metadata,
557 [](const Array& left, const Array& right) { return left.Equals(right); });
558 }
559
ApproxCompareBatch(const RecordBatch & left,const RecordBatch & right,bool compare_metadata)560 void ApproxCompareBatch(const RecordBatch& left, const RecordBatch& right,
561 bool compare_metadata) {
562 return CompareBatchWith(
563 left, right, compare_metadata,
564 [](const Array& left, const Array& right) { return left.ApproxEquals(right); });
565 }
566
TweakValidityBit(const std::shared_ptr<Array> & array,int64_t index,bool validity)567 std::shared_ptr<Array> TweakValidityBit(const std::shared_ptr<Array>& array,
568 int64_t index, bool validity) {
569 auto data = array->data()->Copy();
570 if (data->buffers[0] == nullptr) {
571 data->buffers[0] = *AllocateBitmap(data->length);
572 BitUtil::SetBitsTo(data->buffers[0]->mutable_data(), 0, data->length, true);
573 }
574 BitUtil::SetBitTo(data->buffers[0]->mutable_data(), index, validity);
575 data->null_count = kUnknownNullCount;
576 // Need to return a new array, because Array caches the null bitmap pointer
577 return MakeArray(data);
578 }
579
// Return true if a std::locale can be constructed from the name `locale`,
// false if that construction throws std::runtime_error.
bool LocaleExists(const char* locale) {
  bool available = true;
  try {
    static_cast<void>(std::locale(locale));
  } catch (const std::runtime_error&) {
    available = false;
  }
  return available;
}
588
589 class LocaleGuard::Impl {
590 public:
Impl(const char * new_locale)591 explicit Impl(const char* new_locale) : global_locale_(std::locale()) {
592 try {
593 std::locale::global(std::locale(new_locale));
594 } catch (std::runtime_error&) {
595 ARROW_LOG(WARNING) << "Locale unavailable (ignored): '" << new_locale << "'";
596 }
597 }
598
~Impl()599 ~Impl() { std::locale::global(global_locale_); }
600
601 protected:
602 std::locale global_locale_;
603 };
604
LocaleGuard(const char * new_locale)605 LocaleGuard::LocaleGuard(const char* new_locale) : impl_(new Impl(new_locale)) {}
606
~LocaleGuard()607 LocaleGuard::~LocaleGuard() {}
608
EnvVarGuard(const std::string & name,const std::string & value)609 EnvVarGuard::EnvVarGuard(const std::string& name, const std::string& value)
610 : name_(name) {
611 auto maybe_value = arrow::internal::GetEnvVar(name);
612 if (maybe_value.ok()) {
613 was_set_ = true;
614 old_value_ = *std::move(maybe_value);
615 } else {
616 was_set_ = false;
617 }
618 ARROW_CHECK_OK(arrow::internal::SetEnvVar(name, value));
619 }
620
~EnvVarGuard()621 EnvVarGuard::~EnvVarGuard() {
622 if (was_set_) {
623 ARROW_CHECK_OK(arrow::internal::SetEnvVar(name_, old_value_));
624 } else {
625 ARROW_CHECK_OK(arrow::internal::DelEnvVar(name_));
626 }
627 }
628
629 struct SignalHandlerGuard::Impl {
630 int signum_;
631 internal::SignalHandler old_handler_;
632
Implarrow::SignalHandlerGuard::Impl633 Impl(int signum, const internal::SignalHandler& handler)
634 : signum_(signum), old_handler_(*internal::SetSignalHandler(signum, handler)) {}
635
~Implarrow::SignalHandlerGuard::Impl636 ~Impl() { ARROW_EXPECT_OK(internal::SetSignalHandler(signum_, old_handler_)); }
637 };
638
SignalHandlerGuard(int signum,Callback cb)639 SignalHandlerGuard::SignalHandlerGuard(int signum, Callback cb)
640 : SignalHandlerGuard(signum, internal::SignalHandler(cb)) {}
641
SignalHandlerGuard(int signum,const internal::SignalHandler & handler)642 SignalHandlerGuard::SignalHandlerGuard(int signum, const internal::SignalHandler& handler)
643 : impl_(new Impl{signum, handler}) {}
644
645 SignalHandlerGuard::~SignalHandlerGuard() = default;
646
namespace {

// Sink variable written by TestInitialized(). Being volatile, writes to it
// cannot be optimized away, which keeps otherwise side-effect-less statements
// alive.
volatile int throw_away{0};

}  // namespace
653
AssertZeroPadded(const Array & array)654 void AssertZeroPadded(const Array& array) {
655 for (const auto& buffer : array.data()->buffers) {
656 if (buffer) {
657 const int64_t padding = buffer->capacity() - buffer->size();
658 if (padding > 0) {
659 std::vector<uint8_t> zeros(padding);
660 ASSERT_EQ(0, memcmp(buffer->data() + buffer->size(), zeros.data(), padding));
661 }
662 }
663 }
664 }
665
TestInitialized(const Array & array)666 void TestInitialized(const Array& array) { TestInitialized(*array.data()); }
667
TestInitialized(const ArrayData & array)668 void TestInitialized(const ArrayData& array) {
669 uint8_t total = 0;
670 for (const auto& buffer : array.buffers) {
671 if (buffer && buffer->capacity() > 0) {
672 auto data = buffer->data();
673 for (int64_t i = 0; i < buffer->size(); ++i) {
674 total ^= data[i];
675 }
676 }
677 }
678 uint8_t total_bit = 0;
679 for (uint32_t mask = 1; mask < 256; mask <<= 1) {
680 total_bit ^= (total & mask) != 0;
681 }
682 // This is a dummy condition on all the bits of `total` (which depend on the
683 // entire buffer data). If not all bits are well-defined, Valgrind will
684 // error with "Conditional jump or move depends on uninitialised value(s)".
685 if (total_bit == 0) {
686 ++throw_away;
687 }
688 for (const auto& child : array.child_data) {
689 TestInitialized(*child);
690 }
691 if (array.dictionary) {
692 TestInitialized(*array.dictionary);
693 }
694 }
695
// Block the calling thread for `seconds`. The duration is converted to a whole
// number of nanoseconds (fractional nanoseconds are truncated).
void SleepFor(double seconds) {
  const auto whole_nanos = static_cast<int64_t>(seconds * 1e9);
  const std::chrono::nanoseconds duration(whole_nanos);
  std::this_thread::sleep_for(duration);
}
700
701 #ifdef _WIN32
SleepABit()702 void SleepABit() {
703 LARGE_INTEGER freq, start, now;
704 QueryPerformanceFrequency(&freq);
705 // 1 ms
706 auto desired = freq.QuadPart / 1000;
707 if (desired <= 0) {
708 // Fallback to STL sleep if high resolution clock not available, tests may fail,
709 // shouldn't really happen
710 SleepFor(1e-3);
711 return;
712 }
713 QueryPerformanceCounter(&start);
714 while (true) {
715 std::this_thread::yield();
716 QueryPerformanceCounter(&now);
717 auto elapsed = now.QuadPart - start.QuadPart;
718 if (elapsed > desired) {
719 break;
720 }
721 }
722 }
723 #else
724 // std::this_thread::sleep_for should be high enough resolution on non-Windows systems
SleepABit()725 void SleepABit() { SleepFor(1e-3); }
726 #endif
727
BusyWait(double seconds,std::function<bool ()> predicate)728 void BusyWait(double seconds, std::function<bool()> predicate) {
729 const double period = 0.001;
730 for (int i = 0; !predicate() && i * period < seconds; ++i) {
731 SleepFor(period);
732 }
733 }
734
SleepAsync(double seconds)735 Future<> SleepAsync(double seconds) {
736 auto out = Future<>::Make();
737 std::thread([out, seconds]() mutable {
738 SleepFor(seconds);
739 out.MarkFinished();
740 }).detach();
741 return out;
742 }
743
SleepABitAsync()744 Future<> SleepABitAsync() {
745 auto out = Future<>::Make();
746 std::thread([out]() mutable {
747 SleepABit();
748 out.MarkFinished();
749 }).detach();
750 return out;
751 }
752
753 ///////////////////////////////////////////////////////////////////////////
754 // Extension types
755
ExtensionEquals(const ExtensionType & other) const756 bool UuidType::ExtensionEquals(const ExtensionType& other) const {
757 return (other.extension_name() == this->extension_name());
758 }
759
MakeArray(std::shared_ptr<ArrayData> data) const760 std::shared_ptr<Array> UuidType::MakeArray(std::shared_ptr<ArrayData> data) const {
761 DCHECK_EQ(data->type->id(), Type::EXTENSION);
762 DCHECK_EQ("uuid", static_cast<const ExtensionType&>(*data->type).extension_name());
763 return std::make_shared<UuidArray>(data);
764 }
765
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const766 Result<std::shared_ptr<DataType>> UuidType::Deserialize(
767 std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
768 if (serialized != "uuid-serialized") {
769 return Status::Invalid("Type identifier did not match: '", serialized, "'");
770 }
771 if (!storage_type->Equals(*fixed_size_binary(16))) {
772 return Status::Invalid("Invalid storage type for UuidType: ",
773 storage_type->ToString());
774 }
775 return std::make_shared<UuidType>();
776 }
777
ExtensionEquals(const ExtensionType & other) const778 bool SmallintType::ExtensionEquals(const ExtensionType& other) const {
779 return (other.extension_name() == this->extension_name());
780 }
781
MakeArray(std::shared_ptr<ArrayData> data) const782 std::shared_ptr<Array> SmallintType::MakeArray(std::shared_ptr<ArrayData> data) const {
783 DCHECK_EQ(data->type->id(), Type::EXTENSION);
784 DCHECK_EQ("smallint", static_cast<const ExtensionType&>(*data->type).extension_name());
785 return std::make_shared<SmallintArray>(data);
786 }
787
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const788 Result<std::shared_ptr<DataType>> SmallintType::Deserialize(
789 std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
790 if (serialized != "smallint") {
791 return Status::Invalid("Type identifier did not match: '", serialized, "'");
792 }
793 if (!storage_type->Equals(*int16())) {
794 return Status::Invalid("Invalid storage type for SmallintType: ",
795 storage_type->ToString());
796 }
797 return std::make_shared<SmallintType>();
798 }
799
ExtensionEquals(const ExtensionType & other) const800 bool DictExtensionType::ExtensionEquals(const ExtensionType& other) const {
801 return (other.extension_name() == this->extension_name());
802 }
803
MakeArray(std::shared_ptr<ArrayData> data) const804 std::shared_ptr<Array> DictExtensionType::MakeArray(
805 std::shared_ptr<ArrayData> data) const {
806 DCHECK_EQ(data->type->id(), Type::EXTENSION);
807 DCHECK(ExtensionEquals(checked_cast<const ExtensionType&>(*data->type)));
808 // No need for a specific ExtensionArray derived class
809 return std::make_shared<ExtensionArray>(data);
810 }
811
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const812 Result<std::shared_ptr<DataType>> DictExtensionType::Deserialize(
813 std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
814 if (serialized != "dict-extension-serialized") {
815 return Status::Invalid("Type identifier did not match: '", serialized, "'");
816 }
817 if (!storage_type->Equals(*storage_type_)) {
818 return Status::Invalid("Invalid storage type for DictExtensionType: ",
819 storage_type->ToString());
820 }
821 return std::make_shared<DictExtensionType>();
822 }
823
ExtensionEquals(const ExtensionType & other) const824 bool Complex128Type::ExtensionEquals(const ExtensionType& other) const {
825 return (other.extension_name() == this->extension_name());
826 }
827
MakeArray(std::shared_ptr<ArrayData> data) const828 std::shared_ptr<Array> Complex128Type::MakeArray(std::shared_ptr<ArrayData> data) const {
829 DCHECK_EQ(data->type->id(), Type::EXTENSION);
830 DCHECK(ExtensionEquals(checked_cast<const ExtensionType&>(*data->type)));
831 return std::make_shared<Complex128Array>(data);
832 }
833
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const834 Result<std::shared_ptr<DataType>> Complex128Type::Deserialize(
835 std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
836 if (serialized != "complex128-serialized") {
837 return Status::Invalid("Type identifier did not match: '", serialized, "'");
838 }
839 if (!storage_type->Equals(*storage_type_)) {
840 return Status::Invalid("Invalid storage type for Complex128Type: ",
841 storage_type->ToString());
842 }
843 return std::make_shared<Complex128Type>();
844 }
845
uuid()846 std::shared_ptr<DataType> uuid() { return std::make_shared<UuidType>(); }
847
smallint()848 std::shared_ptr<DataType> smallint() { return std::make_shared<SmallintType>(); }
849
dict_extension_type()850 std::shared_ptr<DataType> dict_extension_type() {
851 return std::make_shared<DictExtensionType>();
852 }
853
complex128()854 std::shared_ptr<DataType> complex128() { return std::make_shared<Complex128Type>(); }
855
MakeComplex128(const std::shared_ptr<Array> & real,const std::shared_ptr<Array> & imag)856 std::shared_ptr<Array> MakeComplex128(const std::shared_ptr<Array>& real,
857 const std::shared_ptr<Array>& imag) {
858 auto type = complex128();
859 std::shared_ptr<Array> storage(
860 new StructArray(checked_cast<const ExtensionType&>(*type).storage_type(),
861 real->length(), {real, imag}));
862 return ExtensionType::WrapArray(type, storage);
863 }
864
ExampleUuid()865 std::shared_ptr<Array> ExampleUuid() {
866 auto arr = ArrayFromJSON(
867 fixed_size_binary(16),
868 "[null, \"abcdefghijklmno0\", \"abcdefghijklmno1\", \"abcdefghijklmno2\"]");
869 return ExtensionType::WrapArray(uuid(), arr);
870 }
871
ExampleSmallint()872 std::shared_ptr<Array> ExampleSmallint() {
873 auto arr = ArrayFromJSON(int16(), "[-32768, null, 1, 2, 3, 4, 32767]");
874 return ExtensionType::WrapArray(smallint(), arr);
875 }
876
ExampleDictExtension()877 std::shared_ptr<Array> ExampleDictExtension() {
878 auto arr = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, null, 1]",
879 R"(["foo", "bar"])");
880 return ExtensionType::WrapArray(dict_extension_type(), arr);
881 }
882
ExampleComplex128()883 std::shared_ptr<Array> ExampleComplex128() {
884 auto arr = ArrayFromJSON(struct_({field("", float64()), field("", float64())}),
885 "[[1.0, -2.5], null, [3.0, -4.5]]");
886 return ExtensionType::WrapArray(complex128(), arr);
887 }
888
ExtensionTypeGuard(const std::shared_ptr<DataType> & type)889 ExtensionTypeGuard::ExtensionTypeGuard(const std::shared_ptr<DataType>& type)
890 : ExtensionTypeGuard(DataTypeVector{type}) {}
891
ExtensionTypeGuard(const DataTypeVector & types)892 ExtensionTypeGuard::ExtensionTypeGuard(const DataTypeVector& types) {
893 for (const auto& type : types) {
894 ARROW_CHECK_EQ(type->id(), Type::EXTENSION);
895 auto ext_type = checked_pointer_cast<ExtensionType>(type);
896
897 ARROW_CHECK_OK(RegisterExtensionType(ext_type));
898 extension_names_.push_back(ext_type->extension_name());
899 DCHECK(!extension_names_.back().empty());
900 }
901 }
902
~ExtensionTypeGuard()903 ExtensionTypeGuard::~ExtensionTypeGuard() {
904 for (const auto& name : extension_names_) {
905 ARROW_CHECK_OK(UnregisterExtensionType(name));
906 }
907 }
908
909 class GatingTask::Impl : public std::enable_shared_from_this<GatingTask::Impl> {
910 public:
Impl(double timeout_seconds)911 explicit Impl(double timeout_seconds)
912 : timeout_seconds_(timeout_seconds), status_(), unlocked_(false) {
913 unlocked_future_ = Future<>::Make();
914 }
915
~Impl()916 ~Impl() {
917 if (num_running_ != num_launched_) {
918 ADD_FAILURE()
919 << "A GatingTask instance was destroyed but some underlying tasks did not "
920 "start running"
921 << std::endl;
922 } else if (num_finished_ != num_launched_) {
923 ADD_FAILURE()
924 << "A GatingTask instance was destroyed but some underlying tasks did not "
925 "finish running"
926 << std::endl;
927 }
928 }
929
Task()930 std::function<void()> Task() {
931 num_launched_++;
932 auto self = shared_from_this();
933 return [self] { self->RunTask(); };
934 }
935
AsyncTask()936 Future<> AsyncTask() {
937 num_launched_++;
938 num_running_++;
939 /// TODO(ARROW-13004) Could maybe implement this check with future chains
940 /// if we check to see if the future has been "consumed" or not
941 num_finished_++;
942 return unlocked_future_;
943 }
944
RunTask()945 void RunTask() {
946 std::unique_lock<std::mutex> lk(mx_);
947 num_running_++;
948 running_cv_.notify_all();
949 if (!unlocked_cv_.wait_for(
950 lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
951 [this] { return unlocked_; })) {
952 status_ &= Status::Invalid("Timed out (" + std::to_string(timeout_seconds_) + "," +
953 std::to_string(unlocked_) +
954 " seconds) waiting for the gating task to be unlocked");
955 }
956 num_finished_++;
957 }
958
WaitForRunning(int count)959 Status WaitForRunning(int count) {
960 std::unique_lock<std::mutex> lk(mx_);
961 if (running_cv_.wait_for(
962 lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
963 [this, count] { return num_running_ >= count; })) {
964 return Status::OK();
965 }
966 return Status::Invalid("Timed out waiting for tasks to launch");
967 }
968
Unlock()969 Status Unlock() {
970 std::lock_guard<std::mutex> lk(mx_);
971 unlocked_ = true;
972 unlocked_cv_.notify_all();
973 unlocked_future_.MarkFinished();
974 return status_;
975 }
976
977 private:
978 double timeout_seconds_;
979 Status status_;
980 bool unlocked_;
981 std::atomic<int> num_launched_{0};
982 int num_running_ = 0;
983 int num_finished_ = 0;
984 std::mutex mx_;
985 std::condition_variable running_cv_;
986 std::condition_variable unlocked_cv_;
987 Future<> unlocked_future_;
988 };
989
GatingTask(double timeout_seconds)990 GatingTask::GatingTask(double timeout_seconds) : impl_(new Impl(timeout_seconds)) {}
991
~GatingTask()992 GatingTask::~GatingTask() {}
993
Task()994 std::function<void()> GatingTask::Task() { return impl_->Task(); }
995
AsyncTask()996 Future<> GatingTask::AsyncTask() { return impl_->AsyncTask(); }
997
Unlock()998 Status GatingTask::Unlock() { return impl_->Unlock(); }
999
WaitForRunning(int count)1000 Status GatingTask::WaitForRunning(int count) { return impl_->WaitForRunning(count); }
1001
Make(double timeout_seconds)1002 std::shared_ptr<GatingTask> GatingTask::Make(double timeout_seconds) {
1003 return std::make_shared<GatingTask>(timeout_seconds);
1004 }
1005
1006 } // namespace arrow
1007