1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <algorithm>
19 #include <functional>
20 #include <limits>
21 #include <memory>
22 #include <ostream>
23 #include <sstream>
24 #include <string>
25 #include <vector>
26
27 #include <gmock/gmock-matchers.h>
28
29 #include "arrow/array/array_decimal.h"
30 #include "arrow/array/concatenate.h"
31 #include "arrow/compute/api_vector.h"
32 #include "arrow/compute/kernels/test_util.h"
33 #include "arrow/result.h"
34 #include "arrow/table.h"
35 #include "arrow/testing/gtest_common.h"
36 #include "arrow/testing/gtest_util.h"
37 #include "arrow/testing/random.h"
38 #include "arrow/testing/util.h"
39 #include "arrow/type_traits.h"
40 #include "arrow/util/logging.h"
41
42 namespace arrow {
43
44 using internal::checked_cast;
45 using internal::checked_pointer_cast;
46
47 namespace compute {
48
AllOrders()49 std::vector<SortOrder> AllOrders() {
50 return {SortOrder::Ascending, SortOrder::Descending};
51 }
52
AllNullPlacements()53 std::vector<NullPlacement> AllNullPlacements() {
54 return {NullPlacement::AtEnd, NullPlacement::AtStart};
55 }
56
// Prints a NullPlacement by name so it reads well in scoped traces.
std::ostream& operator<<(std::ostream& os, NullPlacement null_placement) {
  const char* name = "AtStart";
  if (null_placement == NullPlacement::AtEnd) {
    name = "AtEnd";
  }
  return os << name;
}
61
62 // ----------------------------------------------------------------------
63 // Tests for NthToIndices
64
// Returns the comparable value stored at `index` of `array`, for array types that
// expose GetView(). The return type is exactly what GetView returns, so the
// comparators below compare whatever GetView yields for the array type.
// Decimal arrays are handled by the dedicated overloads that follow.
template <typename ArrayType>
auto GetLogicalValue(const ArrayType& array, uint64_t index)
    -> decltype(array.GetView(index)) {
  return array.GetView(index);
}
70
GetLogicalValue(const Decimal128Array & array,uint64_t index)71 Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
72 return Decimal128(array.Value(index));
73 }
74
GetLogicalValue(const Decimal256Array & array,uint64_t index)75 Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
76 return Decimal256(array.Value(index));
77 }
78
79 template <typename ArrayType>
80 struct ThreeWayComparator {
81 SortOrder order;
82 NullPlacement null_placement;
83
operator ()arrow::compute::ThreeWayComparator84 int operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const {
85 return (*this)(array, array, lhs, rhs);
86 }
87
88 // Return -1 if L < R, 0 if L == R, 1 if L > R
operator ()arrow::compute::ThreeWayComparator89 int operator()(const ArrayType& left, const ArrayType& right, uint64_t lhs,
90 uint64_t rhs) const {
91 const bool lhs_is_null = left.IsNull(lhs);
92 const bool rhs_is_null = right.IsNull(rhs);
93 if (lhs_is_null && rhs_is_null) return 0;
94 if (lhs_is_null) {
95 return null_placement == NullPlacement::AtStart ? -1 : 1;
96 }
97 if (rhs_is_null) {
98 return null_placement == NullPlacement::AtStart ? 1 : -1;
99 }
100 const auto lval = GetLogicalValue(left, lhs);
101 const auto rval = GetLogicalValue(right, rhs);
102 if (is_floating_type<typename ArrayType::TypeClass>::value) {
103 const bool lhs_isnan = lval != lval;
104 const bool rhs_isnan = rval != rval;
105 if (lhs_isnan && rhs_isnan) return 0;
106 if (lhs_isnan) {
107 return null_placement == NullPlacement::AtStart ? -1 : 1;
108 }
109 if (rhs_isnan) {
110 return null_placement == NullPlacement::AtStart ? 1 : -1;
111 }
112 }
113 if (lval == rval) return 0;
114 if (lval < rval) {
115 return order == SortOrder::Ascending ? -1 : 1;
116 } else {
117 return order == SortOrder::Ascending ? 1 : -1;
118 }
119 }
120 };
121
122 template <typename ArrayType>
123 struct NthComparator {
124 ThreeWayComparator<ArrayType> three_way;
125
NthComparatorarrow::compute::NthComparator126 explicit NthComparator(NullPlacement null_placement)
127 : three_way({SortOrder::Ascending, null_placement}) {}
128
129 // Return true iff L <= R
operator ()arrow::compute::NthComparator130 bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const {
131 // lhs <= rhs
132 return three_way(array, lhs, rhs) <= 0;
133 }
134 };
135
136 template <typename ArrayType>
137 struct SortComparator {
138 ThreeWayComparator<ArrayType> three_way;
139
SortComparatorarrow::compute::SortComparator140 explicit SortComparator(SortOrder order, NullPlacement null_placement)
141 : three_way({order, null_placement}) {}
142
operator ()arrow::compute::SortComparator143 bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const {
144 const int r = three_way(array, lhs, rhs);
145 if (r != 0) return r < 0;
146 return lhs < rhs;
147 }
148 };
149
150 template <typename ArrowType>
151 class TestNthToIndicesBase : public TestBase {
152 using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
153
154 protected:
Validate(const ArrayType & array,int n,NullPlacement null_placement,UInt64Array & offsets)155 void Validate(const ArrayType& array, int n, NullPlacement null_placement,
156 UInt64Array& offsets) {
157 if (n >= array.length()) {
158 for (int i = 0; i < array.length(); ++i) {
159 ASSERT_TRUE(offsets.Value(i) == static_cast<uint64_t>(i));
160 }
161 } else {
162 NthComparator<ArrayType> compare{null_placement};
163 uint64_t nth = offsets.Value(n);
164
165 for (int i = 0; i < n; ++i) {
166 uint64_t lhs = offsets.Value(i);
167 ASSERT_TRUE(compare(array, lhs, nth));
168 }
169 for (int i = n + 1; i < array.length(); ++i) {
170 uint64_t rhs = offsets.Value(i);
171 ASSERT_TRUE(compare(array, nth, rhs));
172 }
173 }
174 }
175
AssertNthToIndicesArray(const std::shared_ptr<Array> & values,int n,NullPlacement null_placement)176 void AssertNthToIndicesArray(const std::shared_ptr<Array>& values, int n,
177 NullPlacement null_placement) {
178 ARROW_SCOPED_TRACE("n = ", n, ", null_placement = ", null_placement);
179 ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
180 NthToIndices(*values, PartitionNthOptions(n, null_placement)));
181 // null_count field should have been initialized to 0, for convenience
182 ASSERT_EQ(offsets->data()->null_count, 0);
183 ValidateOutput(*offsets);
184 Validate(*checked_pointer_cast<ArrayType>(values), n, null_placement,
185 *checked_pointer_cast<UInt64Array>(offsets));
186 }
187
AssertNthToIndicesArray(const std::shared_ptr<Array> & values,int n)188 void AssertNthToIndicesArray(const std::shared_ptr<Array>& values, int n) {
189 for (auto null_placement : AllNullPlacements()) {
190 AssertNthToIndicesArray(values, n, null_placement);
191 }
192 }
193
AssertNthToIndicesJson(const std::string & values,int n)194 void AssertNthToIndicesJson(const std::string& values, int n) {
195 AssertNthToIndicesArray(ArrayFromJSON(GetType(), values), n);
196 }
197
198 virtual std::shared_ptr<DataType> GetType() = 0;
199 };
200
201 template <typename ArrowType>
202 class TestNthToIndices : public TestNthToIndicesBase<ArrowType> {
203 protected:
GetType()204 std::shared_ptr<DataType> GetType() override {
205 return default_type_instance<ArrowType>();
206 }
207 };
208
209 template <typename ArrowType>
210 class TestNthToIndicesForReal : public TestNthToIndices<ArrowType> {};
211 TYPED_TEST_SUITE(TestNthToIndicesForReal, RealArrowTypes);
212
213 template <typename ArrowType>
214 class TestNthToIndicesForIntegral : public TestNthToIndices<ArrowType> {};
215 TYPED_TEST_SUITE(TestNthToIndicesForIntegral, IntegralArrowTypes);
216
217 template <typename ArrowType>
218 class TestNthToIndicesForBool : public TestNthToIndices<ArrowType> {};
219 TYPED_TEST_SUITE(TestNthToIndicesForBool, ::testing::Types<BooleanType>);
220
221 template <typename ArrowType>
222 class TestNthToIndicesForTemporal : public TestNthToIndices<ArrowType> {};
223 TYPED_TEST_SUITE(TestNthToIndicesForTemporal, TemporalArrowTypes);
224
225 template <typename ArrowType>
226 class TestNthToIndicesForDecimal : public TestNthToIndicesBase<ArrowType> {
GetType()227 std::shared_ptr<DataType> GetType() override {
228 return std::make_shared<ArrowType>(5, 2);
229 }
230 };
231 TYPED_TEST_SUITE(TestNthToIndicesForDecimal, DecimalArrowTypes);
232
233 template <typename ArrowType>
234 class TestNthToIndicesForStrings : public TestNthToIndices<ArrowType> {};
235 TYPED_TEST_SUITE(TestNthToIndicesForStrings, testing::Types<StringType>);
236
// partition_nth_indices must be called with explicit options; invoking it with
// none is expected to raise Invalid.
TYPED_TEST(TestNthToIndicesForReal, NthToIndicesDoesNotProvideDefaultOptions) {
  const std::vector<Datum> args = {
      ArrayFromJSON(this->GetType(), "[null, 1, 3.3, null, 2, 5.3]")};
  ASSERT_RAISES(Invalid, CallFunction("partition_nth_indices", args));
}
241
// Partitions floating-point inputs mixing nulls and NaNs at several pivots,
// including an out-of-bounds pivot equal to the length.
TYPED_TEST(TestNthToIndicesForReal, Real) {
  const std::string with_nulls = "[null, 1, 3.3, null, 2, 5.3]";
  for (const int n : {0, 2, 5, 6}) {
    this->AssertNthToIndicesJson(with_nulls, n);
  }

  const std::string null_then_nan = "[null, 2, NaN, 3, 1]";
  for (int n = 0; n <= 4; ++n) {
    this->AssertNthToIndicesJson(null_then_nan, n);
  }

  const std::string nan_then_null = "[NaN, 2, null, 3, 1]";
  for (const int n : {3, 4}) {
    this->AssertNthToIndicesJson(nan_then_null, n);
  }

  const std::string two_nans = "[NaN, 2, NaN, 3, 1]";
  for (int n = 0; n <= 4; ++n) {
    this->AssertNthToIndicesJson(two_nans, n);
  }
}
262
// Partitions an integer input with nulls; n == 6 is out of bounds.
TYPED_TEST(TestNthToIndicesForIntegral, Integral) {
  const std::string values = "[null, 1, 3, null, 2, 5]";
  for (const int n : {0, 2, 5, 6}) {
    this->AssertNthToIndicesJson(values, n);
  }
}
269
// Partitions a boolean input with nulls; n == 6 is out of bounds.
TYPED_TEST(TestNthToIndicesForBool, Bool) {
  const std::string values = "[null, false, true, null, false, true]";
  for (const int n : {0, 2, 5, 6}) {
    this->AssertNthToIndicesJson(values, n);
  }
}
276
// Partitions a temporal input with nulls; n == 6 is out of bounds.
TYPED_TEST(TestNthToIndicesForTemporal, Temporal) {
  const std::string values = "[null, 1, 3, null, 2, 5]";
  for (const int n : {0, 2, 5, 6}) {
    this->AssertNthToIndicesJson(values, n);
  }
}
283
// Partitions a decimal input with a null; n == 5 equals the length (out of bounds).
TYPED_TEST(TestNthToIndicesForDecimal, Decimal) {
  const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])";
  for (const int n : {0, 2, 4, 5}) {
    this->AssertNthToIndicesJson(values, n);
  }
}
291
// Partitions a string input with nulls; n == 6 is out of bounds.
TYPED_TEST(TestNthToIndicesForStrings, Strings) {
  const std::string values = R"(["testing", null, "nth", "for", null, "strings"])";
  for (const int n : {0, 2, 5, 6}) {
    this->AssertNthToIndicesJson(values, n);
  }
}
298
// For an all-null array of the null type, NthToIndices is expected to return the
// identity permutation for every pivot and null placement.
TEST(TestNthToIndices, Null) {
  ASSERT_OK_AND_ASSIGN(auto arr, MakeArrayOfNull(null(), 6));
  const auto expected = ArrayFromJSON(uint64(), "[0, 1, 2, 3, 4, 5]");
  for (const auto null_placement : AllNullPlacements()) {
    for (int n = 0; n <= 6; ++n) {
      const PartitionNthOptions options(n, null_placement);
      ASSERT_OK_AND_ASSIGN(auto actual, NthToIndices(*arr, options));
      AssertArraysEqual(*expected, *actual, /*verbose=*/true);
    }
  }
}
310
311 template <typename ArrowType>
312 class TestNthToIndicesRandom : public TestNthToIndicesBase<ArrowType> {
313 public:
GetType()314 std::shared_ptr<DataType> GetType() override {
315 EXPECT_TRUE(0) << "shouldn't be used";
316 return nullptr;
317 }
318 };
319
320 using NthToIndicesableTypes =
321 ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
322 Int32Type, Int64Type, FloatType, DoubleType, Decimal128Type,
323 StringType>;
324
325 TYPED_TEST_SUITE(TestNthToIndicesRandom, NthToIndicesableTypes);
326
// Validates NthToIndices on freshly generated random arrays for every pivot
// from 0 through the (out-of-bounds) length, at several null densities.
TYPED_TEST(TestNthToIndicesRandom, RandomValues) {
  Random<TypeParam> rand(0x61549225);
  const int kLength = 100;
  const double null_probabilities[] = {0.0, 0.1, 0.5, 1.0};
  for (const double null_probability : null_probabilities) {
    // n == kLength is deliberately out of bounds.
    for (int n = 0; n < kLength + 1; ++n) {
      const auto array = rand.Generate(kLength, null_probability);
      this->AssertNthToIndicesArray(array, n);
    }
  }
}
338
339 // ----------------------------------------------------------------------
340 // Tests for SortToIndices
341
342 template <typename T>
AssertSortIndices(const std::shared_ptr<T> & input,SortOrder order,NullPlacement null_placement,const std::shared_ptr<Array> & expected)343 void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
344 NullPlacement null_placement,
345 const std::shared_ptr<Array>& expected) {
346 ArraySortOptions options(order, null_placement);
347 ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options));
348 ValidateOutput(*actual);
349 AssertArraysEqual(*expected, *actual, /*verbose=*/true);
350 }
351
352 template <typename T>
AssertSortIndices(const std::shared_ptr<T> & input,const SortOptions & options,const std::shared_ptr<Array> & expected)353 void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
354 const std::shared_ptr<Array>& expected) {
355 ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*input), options));
356 ValidateOutput(*actual);
357 AssertArraysEqual(*expected, *actual, /*verbose=*/true);
358 }
359
360 template <typename T>
AssertSortIndices(const std::shared_ptr<T> & input,const SortOptions & options,const std::string & expected)361 void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
362 const std::string& expected) {
363 AssertSortIndices(input, options, ArrayFromJSON(uint64(), expected));
364 }
365
366 template <typename T>
AssertSortIndices(const std::shared_ptr<T> & input,SortOrder order,NullPlacement null_placement,const std::string & expected)367 void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
368 NullPlacement null_placement, const std::string& expected) {
369 AssertSortIndices(input, order, null_placement, ArrayFromJSON(uint64(), expected));
370 }
371
AssertSortIndices(const std::shared_ptr<DataType> & type,const std::string & values,SortOrder order,NullPlacement null_placement,const std::string & expected)372 void AssertSortIndices(const std::shared_ptr<DataType>& type, const std::string& values,
373 SortOrder order, NullPlacement null_placement,
374 const std::string& expected) {
375 AssertSortIndices(ArrayFromJSON(type, values), order, null_placement,
376 ArrayFromJSON(uint64(), expected));
377 }
378
379 class TestArraySortIndicesBase : public TestBase {
380 public:
381 virtual std::shared_ptr<DataType> type() = 0;
382
AssertSortIndices(const std::string & values,SortOrder order,NullPlacement null_placement,const std::string & expected)383 virtual void AssertSortIndices(const std::string& values, SortOrder order,
384 NullPlacement null_placement,
385 const std::string& expected) {
386 arrow::compute::AssertSortIndices(this->type(), values, order, null_placement,
387 expected);
388 }
389
AssertSortIndices(const std::string & values,const std::string & expected)390 virtual void AssertSortIndices(const std::string& values, const std::string& expected) {
391 AssertSortIndices(values, SortOrder::Ascending, NullPlacement::AtEnd, expected);
392 }
393 };
394
395 template <typename ArrowType>
396 class TestArraySortIndices : public TestArraySortIndicesBase {
397 public:
type()398 std::shared_ptr<DataType> type() override {
399 // Will choose default parameters for temporal types
400 return std::make_shared<ArrowType>();
401 }
402 };
403
404 template <typename ArrowType>
405 class TestArraySortIndicesForReal : public TestArraySortIndices<ArrowType> {};
406 TYPED_TEST_SUITE(TestArraySortIndicesForReal, RealArrowTypes);
407
408 template <typename ArrowType>
409 class TestArraySortIndicesForBool : public TestArraySortIndices<ArrowType> {};
410 TYPED_TEST_SUITE(TestArraySortIndicesForBool, ::testing::Types<BooleanType>);
411
412 template <typename ArrowType>
413 class TestArraySortIndicesForIntegral : public TestArraySortIndices<ArrowType> {};
414 TYPED_TEST_SUITE(TestArraySortIndicesForIntegral, IntegralArrowTypes);
415
416 template <typename ArrowType>
417 class TestArraySortIndicesForTemporal : public TestArraySortIndices<ArrowType> {};
418 TYPED_TEST_SUITE(TestArraySortIndicesForTemporal, TemporalArrowTypes);
419
420 using StringSortTestTypes = testing::Types<StringType, LargeStringType>;
421
422 template <typename ArrowType>
423 class TestArraySortIndicesForStrings : public TestArraySortIndices<ArrowType> {};
424 TYPED_TEST_SUITE(TestArraySortIndicesForStrings, StringSortTestTypes);
425
426 class TestArraySortIndicesForFixedSizeBinary : public TestArraySortIndicesBase {
427 public:
type()428 std::shared_ptr<DataType> type() override { return fixed_size_binary(3); }
429 };
430
// SortIndices on floating-point inputs, including nulls and NaNs.
TYPED_TEST(TestArraySortIndicesForReal, SortReal) {
  for (const auto null_placement : AllNullPlacements()) {
    // Empty and all-null inputs come back in input order for every direction.
    for (const auto order : AllOrders()) {
      this->AssertSortIndices("[]", order, null_placement, "[]");
      this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
    }
    // These inputs contain no nulls or NaNs, so the expected output is the same
    // for every null placement.
    const SortOrder ascending = SortOrder::Ascending;
    this->AssertSortIndices("[3.4, 2.6, 6.3]", ascending, null_placement, "[1, 0, 2]");
    this->AssertSortIndices("[1.1, 2.4, 3.5, 4.3, 5.1, 6.8, 7.3]", ascending,
                            null_placement, "[0, 1, 2, 3, 4, 5, 6]");
    this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", ascending, null_placement,
                            "[6, 5, 4, 3, 2, 1, 0]");
    this->AssertSortIndices("[10.4, 12, 4.2, 50, 50.3, 32, 11]", ascending,
                            null_placement, "[2, 0, 6, 1, 5, 3, 4]");
  }

  // Nulls only: they go to the requested end regardless of sort order.
  const char* with_nulls = "[null, 1, 3.3, null, 2, 5.3]";
  this->AssertSortIndices(with_nulls, SortOrder::Ascending, NullPlacement::AtEnd,
                          "[1, 4, 2, 5, 0, 3]");
  this->AssertSortIndices(with_nulls, SortOrder::Ascending, NullPlacement::AtStart,
                          "[0, 3, 1, 4, 2, 5]");
  this->AssertSortIndices(with_nulls, SortOrder::Descending, NullPlacement::AtEnd,
                          "[5, 2, 4, 1, 0, 3]");
  this->AssertSortIndices(with_nulls, SortOrder::Descending, NullPlacement::AtStart,
                          "[0, 3, 5, 2, 4, 1]");

  // A NaN and a null: both follow null_placement, with the null outermost.
  const char* nan_and_null = "[3, 4, NaN, 1, 2, null]";
  this->AssertSortIndices(nan_and_null, SortOrder::Ascending, NullPlacement::AtEnd,
                          "[3, 4, 0, 1, 2, 5]");
  this->AssertSortIndices(nan_and_null, SortOrder::Ascending, NullPlacement::AtStart,
                          "[5, 2, 3, 4, 0, 1]");
  this->AssertSortIndices(nan_and_null, SortOrder::Descending, NullPlacement::AtEnd,
                          "[1, 0, 4, 3, 2, 5]");
  this->AssertSortIndices(nan_and_null, SortOrder::Descending, NullPlacement::AtStart,
                          "[5, 2, 1, 0, 4, 3]");

  // NaNs only.
  const char* with_nans = "[NaN, 2, NaN, 3, 1]";
  this->AssertSortIndices(with_nans, SortOrder::Ascending, NullPlacement::AtEnd,
                          "[4, 1, 3, 0, 2]");
  this->AssertSortIndices(with_nans, SortOrder::Ascending, NullPlacement::AtStart,
                          "[0, 2, 4, 1, 3]");
  this->AssertSortIndices(with_nans, SortOrder::Descending, NullPlacement::AtEnd,
                          "[3, 1, 4, 0, 2]");
  this->AssertSortIndices(with_nans, SortOrder::Descending, NullPlacement::AtStart,
                          "[0, 2, 3, 1, 4]");

  // Nothing but nulls and NaNs.
  const char* nulls_and_nans = "[null, NaN, NaN, null]";
  this->AssertSortIndices(nulls_and_nans, SortOrder::Ascending, NullPlacement::AtEnd,
                          "[1, 2, 0, 3]");
  this->AssertSortIndices(nulls_and_nans, SortOrder::Ascending, NullPlacement::AtStart,
                          "[0, 3, 1, 2]");
  this->AssertSortIndices(nulls_and_nans, SortOrder::Descending, NullPlacement::AtEnd,
                          "[1, 2, 0, 3]");
  this->AssertSortIndices(nulls_and_nans, SortOrder::Descending, NullPlacement::AtStart,
                          "[0, 3, 1, 2]");
}
483
// SortIndices on integer inputs, with and without nulls.
TYPED_TEST(TestArraySortIndicesForIntegral, SortIntegral) {
  for (const auto null_placement : AllNullPlacements()) {
    // Empty and all-null inputs come back in input order for every direction.
    for (const auto order : AllOrders()) {
      this->AssertSortIndices("[]", order, null_placement, "[]");
      this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
    }
    // Null-free inputs; equal values keep their input order.
    const SortOrder ascending = SortOrder::Ascending;
    this->AssertSortIndices("[1, 2, 3, 4, 5, 6, 7]", ascending, null_placement,
                            "[0, 1, 2, 3, 4, 5, 6]");
    this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", ascending, null_placement,
                            "[6, 5, 4, 3, 2, 1, 0]");

    const char* with_ties = "[10, 12, 4, 50, 50, 32, 11]";
    this->AssertSortIndices(with_ties, ascending, null_placement,
                            "[2, 0, 6, 1, 5, 3, 4]");
    this->AssertSortIndices(with_ties, SortOrder::Descending, null_placement,
                            "[3, 4, 5, 1, 6, 0, 2]");
  }

  // Values with a small range (use a counting sort)
  const char* with_nulls = "[null, 1, 3, null, 2, 5]";
  this->AssertSortIndices(with_nulls, SortOrder::Ascending, NullPlacement::AtEnd,
                          "[1, 4, 2, 5, 0, 3]");
  this->AssertSortIndices(with_nulls, SortOrder::Ascending, NullPlacement::AtStart,
                          "[0, 3, 1, 4, 2, 5]");
  this->AssertSortIndices(with_nulls, SortOrder::Descending, NullPlacement::AtEnd,
                          "[5, 2, 4, 1, 0, 3]");
  this->AssertSortIndices(with_nulls, SortOrder::Descending, NullPlacement::AtStart,
                          "[0, 3, 5, 2, 4, 1]");
}
511
// SortIndices on boolean inputs, with and without nulls.
TYPED_TEST(TestArraySortIndicesForBool, SortBool) {
  for (const auto null_placement : AllNullPlacements()) {
    // Empty and all-null inputs come back in input order for every direction.
    for (const auto order : AllOrders()) {
      this->AssertSortIndices("[]", order, null_placement, "[]");
      this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
    }
    // Null-free inputs; equal values keep their input order.
    const SortOrder ascending = SortOrder::Ascending;
    this->AssertSortIndices("[true, true, false]", ascending, null_placement,
                            "[2, 0, 1]");
    this->AssertSortIndices("[false, false, false, true, true, true, true]", ascending,
                            null_placement, "[0, 1, 2, 3, 4, 5, 6]");
    this->AssertSortIndices("[true, true, true, true, false, false, false]", ascending,
                            null_placement, "[4, 5, 6, 0, 1, 2, 3]");

    const char* interleaved = "[false, true, false, true, true, false, false]";
    this->AssertSortIndices(interleaved, ascending, null_placement,
                            "[0, 2, 5, 6, 1, 3, 4]");
    this->AssertSortIndices(interleaved, SortOrder::Descending, null_placement,
                            "[1, 3, 4, 0, 2, 5, 6]");
  }

  const char* with_nulls = "[null, true, false, null, false, true]";
  this->AssertSortIndices(with_nulls, SortOrder::Ascending, NullPlacement::AtEnd,
                          "[2, 4, 1, 5, 0, 3]");
  this->AssertSortIndices(with_nulls, SortOrder::Ascending, NullPlacement::AtStart,
                          "[0, 3, 2, 4, 1, 5]");
  this->AssertSortIndices(with_nulls, SortOrder::Descending, NullPlacement::AtEnd,
                          "[1, 5, 2, 4, 0, 3]");
  this->AssertSortIndices(with_nulls, SortOrder::Descending, NullPlacement::AtStart,
                          "[0, 3, 1, 5, 2, 4]");
}
544
// SortIndices on temporal inputs (default type parameters), with and without nulls.
TYPED_TEST(TestArraySortIndicesForTemporal, SortTemporal) {
  for (const auto null_placement : AllNullPlacements()) {
    // Empty and all-null inputs come back in input order for every direction.
    for (const auto order : AllOrders()) {
      this->AssertSortIndices("[]", order, null_placement, "[]");
      this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
    }
    // Null-free inputs; equal values keep their input order.
    const SortOrder ascending = SortOrder::Ascending;
    this->AssertSortIndices("[3, 2, 6]", ascending, null_placement, "[1, 0, 2]");
    this->AssertSortIndices("[1, 2, 3, 4, 5, 6, 7]", ascending, null_placement,
                            "[0, 1, 2, 3, 4, 5, 6]");
    this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", ascending, null_placement,
                            "[6, 5, 4, 3, 2, 1, 0]");

    const char* with_ties = "[10, 12, 4, 50, 50, 32, 11]";
    this->AssertSortIndices(with_ties, ascending, null_placement,
                            "[2, 0, 6, 1, 5, 3, 4]");
    this->AssertSortIndices(with_ties, SortOrder::Descending, null_placement,
                            "[3, 4, 5, 1, 6, 0, 2]");
  }

  const char* with_nulls = "[null, 1, 3, null, 2, 5]";
  this->AssertSortIndices(with_nulls, SortOrder::Ascending, NullPlacement::AtEnd,
                          "[1, 4, 2, 5, 0, 3]");
  this->AssertSortIndices(with_nulls, SortOrder::Ascending, NullPlacement::AtStart,
                          "[0, 3, 1, 4, 2, 5]");
  this->AssertSortIndices(with_nulls, SortOrder::Descending, NullPlacement::AtEnd,
                          "[5, 2, 4, 1, 0, 3]");
  this->AssertSortIndices(with_nulls, SortOrder::Descending, NullPlacement::AtStart,
                          "[0, 3, 5, 2, 4, 1]");
}
573
// SortIndices on string inputs, with and without nulls.
TYPED_TEST(TestArraySortIndicesForStrings, SortStrings) {
  for (const auto null_placement : AllNullPlacements()) {
    // Empty and all-null inputs come back in input order for every direction.
    for (const auto order : AllOrders()) {
      this->AssertSortIndices("[]", order, null_placement, "[]");
      this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
    }
    // Null-free inputs: lexicographic order.
    const SortOrder ascending = SortOrder::Ascending;
    this->AssertSortIndices(R"(["a", "b", "c"])", ascending, null_placement,
                            "[0, 1, 2]");
    this->AssertSortIndices(R"(["foo", "bar", "baz"])", ascending, null_placement,
                            "[1, 2, 0]");
    this->AssertSortIndices(R"(["testing", "sort", "for", "strings"])", ascending,
                            null_placement, "[2, 1, 3, 0]");
  }

  struct Expectation {
    SortOrder order;
    NullPlacement null_placement;
    const char* expected;
  };
  const char* input = R"([null, "c", "b", null, "a", "b"])";
  const Expectation expectations[] = {
      {SortOrder::Ascending, NullPlacement::AtEnd, "[4, 2, 5, 1, 0, 3]"},
      {SortOrder::Ascending, NullPlacement::AtStart, "[0, 3, 4, 2, 5, 1]"},
      {SortOrder::Descending, NullPlacement::AtEnd, "[1, 2, 5, 4, 0, 3]"},
      {SortOrder::Descending, NullPlacement::AtStart, "[0, 3, 1, 2, 5, 4]"},
  };
  for (const Expectation& e : expectations) {
    this->AssertSortIndices(input, e.order, e.null_placement, e.expected);
  }
}
598
// SortIndices on 3-byte fixed-size binary inputs, with and without nulls.
TEST_F(TestArraySortIndicesForFixedSizeBinary, SortFixedSizeBinary) {
  for (const auto null_placement : AllNullPlacements()) {
    // Empty and all-null inputs come back in input order for every direction.
    for (const auto order : AllOrders()) {
      this->AssertSortIndices("[]", order, null_placement, "[]");
      this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
    }
    const char* distinct = R"(["def", "abc", "ghi"])";
    this->AssertSortIndices(distinct, SortOrder::Ascending, null_placement, "[1, 0, 2]");
    this->AssertSortIndices(distinct, SortOrder::Descending, null_placement, "[2, 0, 1]");
  }

  struct Expectation {
    SortOrder order;
    NullPlacement null_placement;
    const char* expected;
  };
  const char* input = R"([null, "ccc", "bbb", null, "aaa", "bbb"])";
  const Expectation expectations[] = {
      {SortOrder::Ascending, NullPlacement::AtEnd, "[4, 2, 5, 1, 0, 3]"},
      {SortOrder::Ascending, NullPlacement::AtStart, "[0, 3, 4, 2, 5, 1]"},
      {SortOrder::Descending, NullPlacement::AtEnd, "[1, 2, 5, 4, 0, 3]"},
      {SortOrder::Descending, NullPlacement::AtStart, "[0, 3, 1, 2, 5, 4]"},
  };
  for (const Expectation& e : expectations) {
    this->AssertSortIndices(input, e.order, e.null_placement, e.expected);
  }
}
621
622 template <typename ArrowType>
623 class TestArraySortIndicesForUInt8 : public TestArraySortIndices<ArrowType> {};
624 TYPED_TEST_SUITE(TestArraySortIndicesForUInt8, UInt8Type);
625
626 template <typename ArrowType>
627 class TestArraySortIndicesForInt8 : public TestArraySortIndices<ArrowType> {};
628 TYPED_TEST_SUITE(TestArraySortIndicesForInt8, Int8Type);
629
// uint8 input covering both ends of the range (0 and 255), ties, and nulls.
TYPED_TEST(TestArraySortIndicesForUInt8, SortUInt8) {
  struct Expectation {
    SortOrder order;
    NullPlacement null_placement;
    const char* expected;
  };
  const char* input = "[255, null, 0, 255, 10, null, 128, 0]";
  const Expectation expectations[] = {
      {SortOrder::Ascending, NullPlacement::AtEnd, "[2, 7, 4, 6, 0, 3, 1, 5]"},
      {SortOrder::Ascending, NullPlacement::AtStart, "[1, 5, 2, 7, 4, 6, 0, 3]"},
      {SortOrder::Descending, NullPlacement::AtEnd, "[0, 3, 6, 4, 2, 7, 1, 5]"},
      {SortOrder::Descending, NullPlacement::AtStart, "[1, 5, 0, 3, 6, 4, 2, 7]"},
  };
  for (const Expectation& e : expectations) {
    this->AssertSortIndices(input, e.order, e.null_placement, e.expected);
  }
}
641
// int8 input covering both ends of the range (-128 and 127), ties, and nulls.
TYPED_TEST(TestArraySortIndicesForInt8, SortInt8) {
  struct Expectation {
    SortOrder order;
    NullPlacement null_placement;
    const char* expected;
  };
  const char* input = "[127, null, -128, 127, 0, null, 10, -128]";
  const Expectation expectations[] = {
      {SortOrder::Ascending, NullPlacement::AtEnd, "[2, 7, 4, 6, 0, 3, 1, 5]"},
      {SortOrder::Ascending, NullPlacement::AtStart, "[1, 5, 2, 7, 4, 6, 0, 3]"},
      {SortOrder::Descending, NullPlacement::AtEnd, "[0, 3, 6, 4, 2, 7, 1, 5]"},
      {SortOrder::Descending, NullPlacement::AtStart, "[1, 5, 0, 3, 6, 4, 2, 7]"},
  };
  for (const Expectation& e : expectations) {
    this->AssertSortIndices(input, e.order, e.null_placement, e.expected);
  }
}
653
654 template <typename ArrowType>
655 class TestArraySortIndicesForInt64 : public TestArraySortIndices<ArrowType> {};
656 TYPED_TEST_SUITE(TestArraySortIndicesForInt64, Int64Type);
657
// int64 input whose values span a wide range, plus nulls.
TYPED_TEST(TestArraySortIndicesForInt64, SortInt64) {
  // Values with a large range (use a comparison-based sort)
  const char* input =
      "[null, -2000000000000000, 3000000000000000,"
      " null, -1000000000000000, 5000000000000000]";
  struct Expectation {
    SortOrder order;
    NullPlacement null_placement;
    const char* expected;
  };
  const Expectation expectations[] = {
      {SortOrder::Ascending, NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]"},
      {SortOrder::Ascending, NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]"},
      {SortOrder::Descending, NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]"},
      {SortOrder::Descending, NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]"},
  };
  for (const Expectation& e : expectations) {
    this->AssertSortIndices(input, e.order, e.null_placement, e.expected);
  }
}
672
673 template <typename ArrowType>
674 class TestArraySortIndicesForDecimal : public TestArraySortIndicesBase {
675 public:
type()676 std::shared_ptr<DataType> type() override { return std::make_shared<ArrowType>(5, 2); }
677 };
678 TYPED_TEST_SUITE(TestArraySortIndicesForDecimal, DecimalArrowTypes);
679
// Decimal input with positive, negative, and null values.
TYPED_TEST(TestArraySortIndicesForDecimal, DecimalSortTestTypes) {
  struct Expectation {
    SortOrder order;
    NullPlacement null_placement;
    const char* expected;
  };
  const char* input = R"(["123.45", null, "-123.45", "456.78", "-456.78", null])";
  const Expectation expectations[] = {
      {SortOrder::Ascending, NullPlacement::AtEnd, "[4, 2, 0, 3, 1, 5]"},
      {SortOrder::Ascending, NullPlacement::AtStart, "[1, 5, 4, 2, 0, 3]"},
      {SortOrder::Descending, NullPlacement::AtEnd, "[3, 0, 2, 4, 1, 5]"},
      {SortOrder::Descending, NullPlacement::AtStart, "[1, 5, 3, 0, 2, 4]"},
  };
  for (const Expectation& e : expectations) {
    this->AssertSortIndices(input, e.order, e.null_placement, e.expected);
  }
}
691
// Null-type arrays (contiguous and chunked, including an empty chunk) sort to
// the identity permutation for every order and null placement.
TEST(TestArraySortIndices, NullType) {
  const auto chunked =
      ChunkedArrayFromJSON(null(), {"[null, null]", "[]", "[null]", "[null]"});
  const std::string all_null = "[null, null, null, null]";
  const std::string identity = "[0, 1, 2, 3]";
  for (const auto null_placement : AllNullPlacements()) {
    for (const auto order : AllOrders()) {
      AssertSortIndices(null(), all_null, order, null_placement, identity);
      AssertSortIndices(chunked, order, null_placement, identity);
    }
  }
}
702
TEST(TestArraySortIndices, TemporalTypeParameters) {
  // duration and timestamp (naive and zoned) for all four units, followed by
  // time64/time32 with the units listed below.
  std::vector<std::shared_ptr<DataType>> temporal_types;
  for (const auto unit :
       {TimeUnit::NANO, TimeUnit::MICRO, TimeUnit::MILLI, TimeUnit::SECOND}) {
    temporal_types.push_back(duration(unit));
    temporal_types.push_back(timestamp(unit));
    temporal_types.push_back(timestamp(unit, "America/Phoenix"));
  }
  temporal_types.push_back(time64(TimeUnit::NANO));
  temporal_types.push_back(time64(TimeUnit::MICRO));
  temporal_types.push_back(time32(TimeUnit::MILLI));
  temporal_types.push_back(time32(TimeUnit::SECOND));

  const char* with_nulls = "[null, 1, 3, null, 2, 5]";
  const char* with_duplicates = "[10, 12, 4, 50, 50, 32, 11]";
  for (const auto& type : temporal_types) {
    for (const auto placement : AllNullPlacements()) {
      // No valid values at all: the result is the identity for either order.
      for (const auto sort_order : AllOrders()) {
        AssertSortIndices(type, "[]", sort_order, placement, "[]");
        AssertSortIndices(type, "[null, null]", sort_order, placement, "[0, 1]");
      }
      // No nulls: the requested null placement must not change anything.
      AssertSortIndices(type, "[3, 2, 6]", SortOrder::Ascending, placement, "[1, 0, 2]");
      AssertSortIndices(type, "[1, 2, 3, 4, 5, 6, 7]", SortOrder::Ascending, placement,
                        "[0, 1, 2, 3, 4, 5, 6]");
      AssertSortIndices(type, "[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, placement,
                        "[6, 5, 4, 3, 2, 1, 0]");

      // The two 50s (indices 3 and 4) keep their input order in both directions.
      AssertSortIndices(type, with_duplicates, SortOrder::Ascending, placement,
                        "[2, 0, 6, 1, 5, 3, 4]");
      AssertSortIndices(type, with_duplicates, SortOrder::Descending, placement,
                        "[3, 4, 5, 1, 6, 0, 2]");
    }
    AssertSortIndices(type, with_nulls, SortOrder::Ascending, NullPlacement::AtEnd,
                      "[1, 4, 2, 5, 0, 3]");
    AssertSortIndices(type, with_nulls, SortOrder::Ascending, NullPlacement::AtStart,
                      "[0, 3, 1, 4, 2, 5]");
    AssertSortIndices(type, with_nulls, SortOrder::Descending, NullPlacement::AtEnd,
                      "[5, 2, 4, 1, 0, 3]");
    AssertSortIndices(type, with_nulls, SortOrder::Descending, NullPlacement::AtStart,
                      "[0, 3, 5, 2, 4, 1]");
  }
}
742
743 template <typename ArrowType>
744 class TestArraySortIndicesRandom : public TestBase {};
745
746 template <typename ArrowType>
747 class TestArraySortIndicesRandomCount : public TestBase {};
748
749 template <typename ArrowType>
750 class TestArraySortIndicesRandomCompare : public TestBase {};
751
752 using SortIndicesableTypes =
753 ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
754 Int32Type, Int64Type, FloatType, DoubleType, StringType,
755 Decimal128Type, BooleanType>;
756
757 template <typename ArrayType>
ValidateSorted(const ArrayType & array,UInt64Array & offsets,SortOrder order,NullPlacement null_placement)758 void ValidateSorted(const ArrayType& array, UInt64Array& offsets, SortOrder order,
759 NullPlacement null_placement) {
760 ValidateOutput(array);
761 SortComparator<ArrayType> compare{order, null_placement};
762 for (int i = 1; i < array.length(); i++) {
763 uint64_t lhs = offsets.Value(i - 1);
764 uint64_t rhs = offsets.Value(i);
765 ASSERT_TRUE(compare(array, lhs, rhs));
766 }
767 }
768
769 TYPED_TEST_SUITE(TestArraySortIndicesRandom, SortIndicesableTypes);
770
TYPED_TEST(TestArraySortIndicesRandom,SortRandomValues)771 TYPED_TEST(TestArraySortIndicesRandom, SortRandomValues) {
772 using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
773
774 Random<TypeParam> rand(0x5487655);
775 int times = 5;
776 int length = 100;
777 for (int test = 0; test < times; test++) {
778 for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
779 auto array = rand.Generate(length, null_probability);
780 for (auto order : AllOrders()) {
781 for (auto null_placement : AllNullPlacements()) {
782 ArraySortOptions options(order, null_placement);
783 ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
784 SortIndices(*array, options));
785 ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
786 *checked_pointer_cast<UInt64Array>(offsets), order,
787 null_placement);
788 }
789 }
790 }
791 }
792 }
793
794 // Long array with small value range: counting sort
795 // - length >= 1024(CountCompareSorter::countsort_min_len_)
796 // - range <= 4096(CountCompareSorter::countsort_max_range_)
797 TYPED_TEST_SUITE(TestArraySortIndicesRandomCount, IntegralArrowTypes);
798
TYPED_TEST(TestArraySortIndicesRandomCount,SortRandomValuesCount)799 TYPED_TEST(TestArraySortIndicesRandomCount, SortRandomValuesCount) {
800 using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
801
802 RandomRange<TypeParam> rand(0x5487656);
803 int times = 5;
804 int length = 100;
805 int range = 2000;
806 for (int test = 0; test < times; test++) {
807 for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
808 auto array = rand.Generate(length, range, null_probability);
809 for (auto order : AllOrders()) {
810 for (auto null_placement : AllNullPlacements()) {
811 ArraySortOptions options(order, null_placement);
812 ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
813 SortIndices(*array, options));
814 ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
815 *checked_pointer_cast<UInt64Array>(offsets), order,
816 null_placement);
817 }
818 }
819 }
820 }
821 }
822
// Short array (below the counting-sort length threshold): std::stable_sort
824 TYPED_TEST_SUITE(TestArraySortIndicesRandomCompare, IntegralArrowTypes);
825
TYPED_TEST(TestArraySortIndicesRandomCompare,SortRandomValuesCompare)826 TYPED_TEST(TestArraySortIndicesRandomCompare, SortRandomValuesCompare) {
827 using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
828
829 Random<TypeParam> rand(0x5487657);
830 int times = 5;
831 int length = 100;
832 for (int test = 0; test < times; test++) {
833 for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
834 auto array = rand.Generate(length, null_probability);
835 for (auto order : AllOrders()) {
836 for (auto null_placement : AllNullPlacements()) {
837 ArraySortOptions options(order, null_placement);
838 ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
839 SortIndices(*array, options));
840 ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
841 *checked_pointer_cast<UInt64Array>(offsets), order,
842 null_placement);
843 }
844 }
845 }
846 }
847 }
848
849 // Test basic cases for chunked array.
850 class TestChunkedArraySortIndices : public ::testing::Test {};
851
TEST_F(TestChunkedArraySortIndices,Null)852 TEST_F(TestChunkedArraySortIndices, Null) {
853 auto chunked_array = ChunkedArrayFromJSON(uint8(), {
854 "[null, 1]",
855 "[3, null, 2]",
856 "[1]",
857 });
858 AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtEnd,
859 "[1, 5, 4, 2, 0, 3]");
860 AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtStart,
861 "[0, 3, 1, 5, 4, 2]");
862 AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtEnd,
863 "[2, 4, 1, 5, 0, 3]");
864 AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtStart,
865 "[0, 3, 2, 4, 1, 5]");
866 }
867
TEST_F(TestChunkedArraySortIndices, NaN) {
  // Flattened positions: 0 = null, 1 = 1, 2 = 3, 3 = null, 4 = NaN, 5 = NaN,
  // 6 = 1. In every case the NaNs sit between the valid values and the nulls.
  auto chunked =
      ChunkedArrayFromJSON(float32(), {"[null, 1]", "[3, null, NaN]", "[NaN, 1]"});
  struct Case {
    SortOrder order;
    NullPlacement placement;
    const char* expected;
  };
  const Case cases[] = {
      {SortOrder::Ascending, NullPlacement::AtEnd, "[1, 6, 2, 4, 5, 0, 3]"},
      {SortOrder::Ascending, NullPlacement::AtStart, "[0, 3, 4, 5, 1, 6, 2]"},
      {SortOrder::Descending, NullPlacement::AtEnd, "[2, 1, 6, 4, 5, 0, 3]"},
      {SortOrder::Descending, NullPlacement::AtStart, "[0, 3, 4, 5, 2, 1, 6]"},
  };
  for (const Case& c : cases) {
    AssertSortIndices(chunked, c.order, c.placement, c.expected);
  }
}
883
884 // Tests for temporal types
885 template <typename ArrowType>
886 class TestChunkedArraySortIndicesForTemporal : public TestChunkedArraySortIndices {
887 protected:
GetType()888 std::shared_ptr<DataType> GetType() { return default_type_instance<ArrowType>(); }
889 };
890 TYPED_TEST_SUITE(TestChunkedArraySortIndicesForTemporal, TemporalArrowTypes);
891
TYPED_TEST(TestChunkedArraySortIndicesForTemporal,NoNull)892 TYPED_TEST(TestChunkedArraySortIndicesForTemporal, NoNull) {
893 auto type = this->GetType();
894 auto chunked_array = ChunkedArrayFromJSON(type, {
895 "[0, 1]",
896 "[3, 2, 1]",
897 "[5, 0]",
898 });
899 for (auto null_placement : AllNullPlacements()) {
900 AssertSortIndices(chunked_array, SortOrder::Ascending, null_placement,
901 "[0, 6, 1, 4, 3, 2, 5]");
902 AssertSortIndices(chunked_array, SortOrder::Descending, null_placement,
903 "[5, 2, 3, 1, 4, 0, 6]");
904 }
905 }
906
907 // Tests for decimal types
908 template <typename ArrowType>
909 class TestChunkedArraySortIndicesForDecimal : public TestChunkedArraySortIndices {
910 protected:
GetType()911 std::shared_ptr<DataType> GetType() { return std::make_shared<ArrowType>(5, 2); }
912 };
913 TYPED_TEST_SUITE(TestChunkedArraySortIndicesForDecimal, DecimalArrowTypes);
914
TYPED_TEST(TestChunkedArraySortIndicesForDecimal,Basics)915 TYPED_TEST(TestChunkedArraySortIndicesForDecimal, Basics) {
916 auto type = this->GetType();
917 auto chunked_array = ChunkedArrayFromJSON(
918 type, {R"(["123.45", "-123.45"])", R"([null, "456.78"])", R"(["-456.78", null])"});
919 AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtEnd,
920 "[4, 1, 0, 3, 2, 5]");
921 AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtStart,
922 "[2, 5, 4, 1, 0, 3]");
923 AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtEnd,
924 "[3, 0, 1, 4, 2, 5]");
925 AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtStart,
926 "[2, 5, 3, 0, 1, 4]");
927 }
928
929 // Base class for testing against random chunked array.
930 template <typename Type>
931 class TestChunkedArrayRandomBase : public TestBase {
932 protected:
933 // Generates a chunk. This should be implemented in subclasses.
934 virtual std::shared_ptr<Array> GenerateArray(int length, double null_probability) = 0;
935
936 // All tests uses this.
TestSortIndices(int length)937 void TestSortIndices(int length) {
938 using ArrayType = typename TypeTraits<Type>::ArrayType;
939
940 for (auto null_probability : {0.0, 0.1, 0.5, 0.9, 1.0}) {
941 for (auto num_chunks : {1, 2, 5, 10, 40}) {
942 std::vector<std::shared_ptr<Array>> arrays;
943 for (int i = 0; i < num_chunks; ++i) {
944 auto array = this->GenerateArray(length / num_chunks, null_probability);
945 arrays.push_back(array);
946 }
947 ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(arrays));
948 // Concatenate chunks to use existing ValidateSorted() for array.
949 ASSERT_OK_AND_ASSIGN(auto concatenated_array, Concatenate(arrays));
950
951 for (auto order : AllOrders()) {
952 for (auto null_placement : AllNullPlacements()) {
953 ArraySortOptions options(order, null_placement);
954 ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(*chunked_array, options));
955 ValidateSorted<ArrayType>(
956 *checked_pointer_cast<ArrayType>(concatenated_array),
957 *checked_pointer_cast<UInt64Array>(offsets), order, null_placement);
958 }
959 }
960 }
961 }
962 }
963 };
964
965 // Long array with big value range: std::stable_sort
966 template <typename Type>
967 class TestChunkedArrayRandom : public TestChunkedArrayRandomBase<Type> {
968 public:
SetUp()969 void SetUp() override { rand_ = new Random<Type>(0x5487655); }
970
TearDown()971 void TearDown() override { delete rand_; }
972
973 protected:
GenerateArray(int length,double null_probability)974 std::shared_ptr<Array> GenerateArray(int length, double null_probability) override {
975 return rand_->Generate(length, null_probability);
976 }
977
978 private:
979 Random<Type>* rand_;
980 };
981 TYPED_TEST_SUITE(TestChunkedArrayRandom, SortIndicesableTypes);
982
TYPED_TEST(TestChunkedArrayRandom,SortIndices)983 TYPED_TEST(TestChunkedArrayRandom, SortIndices) { this->TestSortIndices(1000); }
984
985 // Long array with small value range: counting sort
986 // - length >= 1024(CountCompareSorter::countsort_min_len_)
987 // - range <= 4096(CountCompareSorter::countsort_max_range_)
988 template <typename Type>
989 class TestChunkedArrayRandomNarrow : public TestChunkedArrayRandomBase<Type> {
990 public:
SetUp()991 void SetUp() override {
992 range_ = 2000;
993 rand_ = new RandomRange<Type>(0x5487655);
994 }
995
TearDown()996 void TearDown() override { delete rand_; }
997
998 protected:
GenerateArray(int length,double null_probability)999 std::shared_ptr<Array> GenerateArray(int length, double null_probability) override {
1000 return rand_->Generate(length, range_, null_probability);
1001 }
1002
1003 private:
1004 int range_;
1005 RandomRange<Type>* rand_;
1006 };
1007 TYPED_TEST_SUITE(TestChunkedArrayRandomNarrow, IntegralArrowTypes);
TYPED_TEST(TestChunkedArrayRandomNarrow,SortIndices)1008 TYPED_TEST(TestChunkedArrayRandomNarrow, SortIndices) { this->TestSortIndices(1000); }
1009
1010 // Test basic cases for record batch.
1011 class TestRecordBatchSortIndices : public ::testing::Test {};
1012
TEST_F(TestRecordBatchSortIndices,NoNull)1013 TEST_F(TestRecordBatchSortIndices, NoNull) {
1014 auto schema = ::arrow::schema({
1015 {field("a", uint8())},
1016 {field("b", uint32())},
1017 });
1018 auto batch = RecordBatchFromJSON(schema,
1019 R"([{"a": 3, "b": 5},
1020 {"a": 1, "b": 3},
1021 {"a": 3, "b": 4},
1022 {"a": 0, "b": 6},
1023 {"a": 2, "b": 5},
1024 {"a": 1, "b": 5},
1025 {"a": 1, "b": 3}
1026 ])");
1027
1028 for (auto null_placement : AllNullPlacements()) {
1029 SortOptions options(
1030 {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)},
1031 null_placement);
1032
1033 AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]");
1034 }
1035 }
1036
TEST_F(TestRecordBatchSortIndices, Null) {
  auto batch_schema = ::arrow::schema({field("a", uint8()), field("b", uint32())});
  auto batch = RecordBatchFromJSON(batch_schema,
                                   R"([{"a": null, "b": 5},
                                       {"a": 1, "b": 3},
                                       {"a": 3, "b": null},
                                       {"a": null, "b": null},
                                       {"a": 2, "b": 5},
                                       {"a": 1, "b": 5},
                                       {"a": 3, "b": 5}
                                      ])");
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(batch, SortOptions(keys, placement), expected);
  };
  check(NullPlacement::AtEnd, "[5, 1, 4, 6, 2, 0, 3]");
  check(NullPlacement::AtStart, "[3, 0, 5, 1, 4, 2, 6]");
}
1059
TEST_F(TestRecordBatchSortIndices, NaN) {
  auto batch_schema = ::arrow::schema({field("a", float32()), field("b", float64())});
  auto batch = RecordBatchFromJSON(batch_schema,
                                   R"([{"a": 3, "b": 5},
                                       {"a": 1, "b": NaN},
                                       {"a": 3, "b": 4},
                                       {"a": 0, "b": 6},
                                       {"a": NaN, "b": 5},
                                       {"a": NaN, "b": NaN},
                                       {"a": NaN, "b": 5},
                                       {"a": 1, "b": 5}
                                      ])");
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(batch, SortOptions(keys, placement), expected);
  };
  // NaNs in either column follow the requested placement.
  check(NullPlacement::AtEnd, "[3, 7, 1, 0, 2, 4, 6, 5]");
  check(NullPlacement::AtStart, "[5, 4, 6, 3, 1, 7, 0, 2]");
}
1083
TEST_F(TestRecordBatchSortIndices, NaNAndNull) {
  auto batch_schema = ::arrow::schema({field("a", float32()), field("b", float64())});
  auto batch = RecordBatchFromJSON(batch_schema,
                                   R"([{"a": null, "b": 5},
                                       {"a": 1, "b": 3},
                                       {"a": 3, "b": null},
                                       {"a": null, "b": null},
                                       {"a": NaN, "b": null},
                                       {"a": NaN, "b": NaN},
                                       {"a": NaN, "b": 5},
                                       {"a": 1, "b": 5}
                                      ])");
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(batch, SortOptions(keys, placement), expected);
  };
  // Within each column NaNs sit between the valid values and the nulls.
  check(NullPlacement::AtEnd, "[7, 1, 2, 6, 5, 4, 0, 3]");
  check(NullPlacement::AtStart, "[3, 0, 4, 5, 6, 7, 1, 2]");
}
1107
TEST_F(TestRecordBatchSortIndices, Boolean) {
  auto batch_schema = ::arrow::schema({field("a", boolean()), field("b", boolean())});
  auto batch = RecordBatchFromJSON(batch_schema,
                                   R"([{"a": true, "b": null},
                                       {"a": false, "b": null},
                                       {"a": true, "b": true},
                                       {"a": false, "b": true},
                                       {"a": true, "b": false},
                                       {"a": null, "b": false},
                                       {"a": false, "b": null},
                                       {"a": null, "b": true}
                                      ])");
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(batch, SortOptions(keys, placement), expected);
  };
  check(NullPlacement::AtEnd, "[3, 1, 6, 2, 4, 0, 7, 5]");
  check(NullPlacement::AtStart, "[7, 5, 1, 6, 3, 0, 2, 4]");
}
1131
TEST_F(TestRecordBatchSortIndices, MoreTypes) {
  auto batch_schema = ::arrow::schema({field("a", timestamp(TimeUnit::MICRO)),
                                       field("b", large_utf8()),
                                       field("c", fixed_size_binary(3))});
  auto batch = RecordBatchFromJSON(batch_schema,
                                   R"([{"a": 3, "b": "05", "c": "aaa"},
                                       {"a": 1, "b": "031", "c": "bbb"},
                                       {"a": 3, "b": "05", "c": "bbb"},
                                       {"a": 0, "b": "0666", "c": "aaa"},
                                       {"a": 2, "b": "05", "c": "aaa"},
                                       {"a": 1, "b": "05", "c": "bbb"}
                                      ])");
  // Three keys of different types; "c" only decides between rows 0 and 2.
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending),
                                  SortKey("c", SortOrder::Ascending)};
  for (const auto placement : AllNullPlacements()) {
    AssertSortIndices(batch, SortOptions(keys, placement), "[3, 5, 1, 4, 0, 2]");
  }
}
1155
TEST_F(TestRecordBatchSortIndices, Decimal) {
  auto batch_schema =
      ::arrow::schema({field("a", decimal128(3, 1)), field("b", decimal256(4, 2))});
  auto batch = RecordBatchFromJSON(batch_schema,
                                   R"([{"a": "12.3", "b": "12.34"},
                                       {"a": "45.6", "b": "12.34"},
                                       {"a": "12.3", "b": "-12.34"},
                                       {"a": "-12.3", "b": null},
                                       {"a": "-12.3", "b": "-45.67"}
                                      ])");
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(batch, SortOptions(keys, placement), expected);
  };
  // Only the null "b" in row 3 is affected by the placement.
  check(NullPlacement::AtEnd, "[4, 3, 0, 2, 1]");
  check(NullPlacement::AtStart, "[3, 4, 0, 2, 1]");
}
1176
TEST_F(TestRecordBatchSortIndices, NullType) {
  // "a" and "i" are null-typed; the int32 columns "d".."h" are constant, so
  // only "b" and "c" can distinguish rows.
  auto batch_schema = arrow::schema({
      field("a", null()),
      field("b", int32()),
      field("c", int32()),
      field("d", int32()),
      field("e", int32()),
      field("f", int32()),
      field("g", int32()),
      field("h", int32()),
      field("i", null()),
  });
  auto batch = RecordBatchFromJSON(batch_schema, R"([
    {"a": null, "b": 5, "c": 0, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null},
    {"a": null, "b": 5, "c": 1, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null},
    {"a": null, "b": 2, "c": 2, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null},
    {"a": null, "b": 4, "c": 3, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null}
    ])");
  for (const auto placement : AllNullPlacements()) {
    for (const auto order : AllOrders()) {
      // Uses radix sorter
      AssertSortIndices(
          batch, SortOptions({SortKey("a", order), SortKey("i", order)}, placement),
          "[0, 1, 2, 3]");
      AssertSortIndices(batch,
                        SortOptions({SortKey("a", order), SortKey("b", SortOrder::Ascending),
                                     SortKey("i", order)},
                                    placement),
                        "[2, 3, 0, 1]");

      // Uses multiple-key sorter: "a", then "b".."h" ascending, then "i".
      std::vector<SortKey> many_keys{SortKey("a", order)};
      for (const char* column : {"b", "c", "d", "e", "f", "g", "h"}) {
        many_keys.push_back(SortKey(column, SortOrder::Ascending));
      }
      many_keys.push_back(SortKey("i", order));
      AssertSortIndices(batch, SortOptions(many_keys, placement), "[2, 3, 0, 1]");
    }
  }
}
1234
TEST_F(TestRecordBatchSortIndices, DuplicateSortKeys) {
  // ARROW-14073: a column listed several times among the sort keys only counts
  // at its first occurrence, so the expected output matches the plain
  // "a" ascending / "b" descending ordering.
  auto batch_schema = ::arrow::schema({field("a", float32()), field("b", float64())});
  auto batch = RecordBatchFromJSON(batch_schema,
                                   R"([{"a": null, "b": 5},
                                       {"a": 1, "b": 3},
                                       {"a": 3, "b": null},
                                       {"a": null, "b": null},
                                       {"a": NaN, "b": null},
                                       {"a": NaN, "b": NaN},
                                       {"a": NaN, "b": 5},
                                       {"a": 1, "b": 5}
                                      ])");
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending),
                                  SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Ascending),
                                  SortKey("a", SortOrder::Descending)};
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(batch, SortOptions(keys, placement), expected);
  };
  check(NullPlacement::AtEnd, "[7, 1, 2, 6, 5, 4, 0, 3]");
  check(NullPlacement::AtStart, "[3, 0, 4, 5, 6, 7, 1, 2]");
}
1262
1263 // Test basic cases for table.
1264 class TestTableSortIndices : public ::testing::Test {};
1265
TEST_F(TestTableSortIndices,EmptyTable)1266 TEST_F(TestTableSortIndices, EmptyTable) {
1267 auto schema = ::arrow::schema({
1268 {field("a", uint8())},
1269 {field("b", uint32())},
1270 });
1271 const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
1272 SortKey("b", SortOrder::Descending)};
1273
1274 auto table = TableFromJSON(schema, {"[]"});
1275 auto chunked_table = TableFromJSON(schema, {"[]", "[]"});
1276
1277 SortOptions options(sort_keys, NullPlacement::AtEnd);
1278 AssertSortIndices(table, options, "[]");
1279 AssertSortIndices(chunked_table, options, "[]");
1280 options.null_placement = NullPlacement::AtStart;
1281 AssertSortIndices(table, options, "[]");
1282 AssertSortIndices(chunked_table, options, "[]");
1283 }
1284
TEST_F(TestTableSortIndices, EmptySortKeys) {
  auto table_schema = ::arrow::schema({field("a", uint8()), field("b", uint32())});
  const SortOptions no_keys(std::vector<SortKey>{}, NullPlacement::AtEnd);

  // Sorting without any key must be rejected, for one chunk and for several.
  const std::vector<std::vector<std::string>> chunkings = {
      {R"([{"a": null, "b": 5}])"},
      {R"([{"a": null, "b": 5}])", R"([{"a": 0, "b": 6}])"},
  };
  for (const auto& json_chunks : chunkings) {
    auto table = TableFromJSON(table_schema, json_chunks);
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        Invalid, testing::HasSubstr("Must specify one or more sort keys"),
        CallFunction("sort_indices", {table}, &no_keys));
  }
}
1304
TEST_F(TestTableSortIndices, Null) {
  auto table_schema = ::arrow::schema({field("a", uint8()), field("b", uint32())});
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};

  // The same seven rows, first in a single chunk and then split 3 + 4.
  std::vector<std::shared_ptr<Table>> tables;
  tables.push_back(TableFromJSON(table_schema, {R"([{"a": null, "b": 5},
                                                  {"a": 1, "b": 3},
                                                  {"a": 3, "b": null},
                                                  {"a": null, "b": null},
                                                  {"a": 2, "b": 5},
                                                  {"a": 1, "b": 5},
                                                  {"a": 3, "b": 5}
                                                 ])"}));
  tables.push_back(TableFromJSON(table_schema, {R"([{"a": null, "b": 5},
                                                  {"a": 1, "b": 3},
                                                  {"a": 3, "b": null}
                                                 ])",
                                                R"([{"a": null, "b": null},
                                                  {"a": 2, "b": 5},
                                                  {"a": 1, "b": 5},
                                                  {"a": 3, "b": 5}
                                                 ])"}));

  for (const auto& table : tables) {
    AssertSortIndices(table, SortOptions(keys, NullPlacement::AtEnd),
                      "[5, 1, 4, 6, 2, 0, 3]");
    AssertSortIndices(table, SortOptions(keys, NullPlacement::AtStart),
                      "[3, 0, 5, 1, 4, 2, 6]");
  }
}
1342
TEST_F(TestTableSortIndices, NaN) {
  auto table_schema = ::arrow::schema({field("a", float32()), field("b", float64())});
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};

  // The same eight rows, first in a single chunk and then split 4 + 4.
  std::vector<std::shared_ptr<Table>> tables;
  tables.push_back(TableFromJSON(table_schema, {R"([{"a": 3, "b": 5},
                                                  {"a": 1, "b": NaN},
                                                  {"a": 3, "b": 4},
                                                  {"a": 0, "b": 6},
                                                  {"a": NaN, "b": 5},
                                                  {"a": NaN, "b": NaN},
                                                  {"a": NaN, "b": 5},
                                                  {"a": 1, "b": 5}
                                                 ])"}));
  tables.push_back(TableFromJSON(table_schema, {R"([{"a": 3, "b": 5},
                                                  {"a": 1, "b": NaN},
                                                  {"a": 3, "b": 4},
                                                  {"a": 0, "b": 6}
                                                 ])",
                                                R"([{"a": NaN, "b": 5},
                                                  {"a": NaN, "b": NaN},
                                                  {"a": NaN, "b": 5},
                                                  {"a": 1, "b": 5}
                                                 ])"}));

  for (const auto& table : tables) {
    AssertSortIndices(table, SortOptions(keys, NullPlacement::AtEnd),
                      "[3, 7, 1, 0, 2, 4, 6, 5]");
    AssertSortIndices(table, SortOptions(keys, NullPlacement::AtStart),
                      "[5, 4, 6, 3, 1, 7, 0, 2]");
  }
}
1382
TEST_F(TestTableSortIndices, NaNAndNull) {
  auto table_schema = ::arrow::schema({field("a", float32()), field("b", float64())});
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};

  // The same eight rows, first in a single chunk and then split 4 + 4.
  std::vector<std::shared_ptr<Table>> tables;
  tables.push_back(TableFromJSON(table_schema, {R"([{"a": null, "b": 5},
                                                  {"a": 1, "b": 3},
                                                  {"a": 3, "b": null},
                                                  {"a": null, "b": null},
                                                  {"a": NaN, "b": null},
                                                  {"a": NaN, "b": NaN},
                                                  {"a": NaN, "b": 5},
                                                  {"a": 1, "b": 5}
                                                 ])"}));
  tables.push_back(TableFromJSON(table_schema, {R"([{"a": null, "b": 5},
                                                  {"a": 1, "b": 3},
                                                  {"a": 3, "b": null},
                                                  {"a": null, "b": null}
                                                 ])",
                                                R"([{"a": NaN, "b": null},
                                                  {"a": NaN, "b": NaN},
                                                  {"a": NaN, "b": 5},
                                                  {"a": 1, "b": 5}
                                                 ])"}));

  for (const auto& table : tables) {
    AssertSortIndices(table, SortOptions(keys, NullPlacement::AtEnd),
                      "[7, 1, 2, 6, 5, 4, 0, 3]");
    AssertSortIndices(table, SortOptions(keys, NullPlacement::AtStart),
                      "[3, 0, 4, 5, 6, 7, 1, 2]");
  }
}
1422
TEST_F(TestTableSortIndices, Boolean) {
  auto table_schema = ::arrow::schema({field("a", boolean()), field("b", boolean())});
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};

  // Eight rows split into two chunks of four.
  auto table = TableFromJSON(table_schema, {R"([{"a": true, "b": null},
                                              {"a": false, "b": null},
                                              {"a": true, "b": true},
                                              {"a": false, "b": true}
                                             ])",
                                            R"([{"a": true, "b": false},
                                              {"a": null, "b": false},
                                              {"a": false, "b": null},
                                              {"a": null, "b": true}
                                             ])"});
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(table, SortOptions(keys, placement), expected);
  };
  check(NullPlacement::AtEnd, "[3, 1, 6, 2, 4, 0, 7, 5]");
  check(NullPlacement::AtStart, "[7, 5, 1, 6, 3, 0, 2, 4]");
}
1446
TEST_F(TestTableSortIndices, BinaryLike) {
  auto table_schema =
      ::arrow::schema({field("a", large_utf8()), field("b", fixed_size_binary(3))});
  // String key descending first, then the fixed-size binary key ascending.
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Descending),
                                  SortKey("b", SortOrder::Ascending)};

  auto table = TableFromJSON(table_schema, {R"([{"a": "one", "b": null},
                                              {"a": "two", "b": "aaa"},
                                              {"a": "three", "b": "bbb"},
                                              {"a": "four", "b": "ccc"}
                                             ])",
                                            R"([{"a": "one", "b": "ddd"},
                                              {"a": "two", "b": "ccc"},
                                              {"a": "three", "b": "bbb"},
                                              {"a": "four", "b": "aaa"}
                                             ])"});
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(table, SortOptions(keys, placement), expected);
  };
  // Only the relative order of rows 0 and 4 (the null "b") depends on placement.
  check(NullPlacement::AtEnd, "[1, 5, 2, 6, 4, 0, 7, 3]");
  check(NullPlacement::AtStart, "[1, 5, 2, 6, 0, 4, 7, 3]");
}
1470
TEST_F(TestTableSortIndices, Decimal) {
  auto table_schema =
      ::arrow::schema({field("a", decimal128(3, 1)), field("b", decimal256(4, 2))});
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending)};

  // Five rows split 3 + 2.
  auto table = TableFromJSON(table_schema, {R"([{"a": "12.3", "b": "12.34"},
                                              {"a": "45.6", "b": "12.34"},
                                              {"a": "12.3", "b": "-12.34"}
                                             ])",
                                            R"([{"a": "-12.3", "b": null},
                                              {"a": "-12.3", "b": "-45.67"}
                                             ])"});
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(table, SortOptions(keys, placement), expected);
  };
  check(NullPlacement::AtEnd, "[4, 3, 0, 2, 1]");
  check(NullPlacement::AtStart, "[3, 4, 0, 2, 1]");
}
1491
TEST_F(TestTableSortIndices, NullType) {
  // "a" and "d" are null-typed; rows differ only in the int32 columns.
  auto table_schema = arrow::schema({
      field("a", null()),
      field("b", int32()),
      field("c", int32()),
      field("d", null()),
  });
  // Three chunks, the middle one empty.
  auto table = TableFromJSON(table_schema, {
                                               R"([
    {"a": null, "b": 5, "c": 0, "d": null},
    {"a": null, "b": 5, "c": 1, "d": null},
    {"a": null, "b": 2, "c": 2, "d": null}
    ])",
                                               R"([])",
                                               R"([{"a": null, "b": 4, "c": 3, "d": null}])",
                                           });
  for (const auto placement : AllNullPlacements()) {
    for (const auto order : AllOrders()) {
      // Only null-typed keys: the input order is preserved.
      AssertSortIndices(
          table, SortOptions({SortKey("a", order), SortKey("d", order)}, placement),
          "[0, 1, 2, 3]");
      // "b" decides; the tie between rows 0 and 1 keeps input order.
      AssertSortIndices(table,
                        SortOptions({SortKey("a", order), SortKey("b", SortOrder::Ascending),
                                     SortKey("d", order)},
                                    placement),
                        "[2, 3, 0, 1]");
    }
  }
}
1530
TEST_F(TestTableSortIndices, DuplicateSortKeys) {
  // ARROW-14073: a column listed several times among the sort keys only counts
  // at its first occurrence, so the expected output matches the plain
  // "a" ascending / "b" descending ordering.
  auto table_schema = ::arrow::schema({field("a", float32()), field("b", float64())});
  const std::vector<SortKey> keys{SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Descending),
                                  SortKey("a", SortOrder::Ascending),
                                  SortKey("b", SortOrder::Ascending),
                                  SortKey("a", SortOrder::Descending)};

  auto table = TableFromJSON(table_schema, {R"([{"a": null, "b": 5},
                                              {"a": 1, "b": 3},
                                              {"a": 3, "b": null},
                                              {"a": null, "b": null}
                                             ])",
                                            R"([{"a": NaN, "b": null},
                                              {"a": NaN, "b": NaN},
                                              {"a": NaN, "b": 5},
                                              {"a": 1, "b": 5}
                                             ])"});
  auto check = [&](NullPlacement placement, const char* expected) {
    AssertSortIndices(table, SortOptions(keys, placement), expected);
  };
  check(NullPlacement::AtEnd, "[7, 1, 2, 6, 5, 4, 0, 3]");
  check(NullPlacement::AtStart, "[3, 0, 4, 5, 6, 7, 1, 2]");
}
1559
TEST_F(TestTableSortIndices, HeterogenousChunking) {
  const auto table_schema = ::arrow::schema({field("a", float32()), field("b", float64())});

  // Same logical data as in the "NaNAndNull" and DuplicateSortKeys tests above, but
  // each column is split into chunks at different row boundaries (including an
  // empty chunk), so the sorter has to cope with misaligned chunks.
  const auto column_a =
      ChunkedArrayFromJSON(float32(), {"[null, 1]", "[]", "[3, null, NaN, NaN, NaN, 1]"});
  const auto column_b =
      ChunkedArrayFromJSON(float64(), {"[5]", "[3, null, null]", "[null, NaN, 5]", "[5]"});
  const auto table = Table::Make(table_schema, {column_a, column_b});

  // "a" ascending, then "b" descending.
  SortOptions a_then_b(
      {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)});
  AssertSortIndices(table, a_then_b, "[7, 1, 2, 6, 5, 4, 0, 3]");
  a_then_b.null_placement = NullPlacement::AtStart;
  AssertSortIndices(table, a_then_b, "[3, 0, 4, 5, 6, 7, 1, 2]");

  // "b" ascending, then "a" descending.
  SortOptions b_then_a(
      {SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)});
  AssertSortIndices(table, b_then_a, "[1, 7, 6, 0, 5, 2, 4, 3]");
  b_then_a.null_placement = NullPlacement::AtStart;
  AssertSortIndices(table, b_then_a, "[3, 4, 2, 5, 1, 0, 6, 7]");
}
1585
1586 // Tests for temporal types
1587 template <typename ArrowType>
1588 class TestTableSortIndicesForTemporal : public TestTableSortIndices {
1589 protected:
GetType()1590 std::shared_ptr<DataType> GetType() { return default_type_instance<ArrowType>(); }
1591 };
1592 TYPED_TEST_SUITE(TestTableSortIndicesForTemporal, TemporalArrowTypes);
1593
TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) {
  const auto type = this->GetType();
  const auto table_schema = schema({field("a", type), field("b", type)});
  const auto table = TableFromJSON(table_schema, {R"([{"a": 0, "b": 5},
                                                      {"a": 1, "b": 3},
                                                      {"a": 3, "b": 0},
                                                      {"a": 2, "b": 1},
                                                      {"a": 1, "b": 3},
                                                      {"a": 5, "b": 0},
                                                      {"a": 0, "b": 4},
                                                      {"a": 1, "b": 2}
                                                     ])"});

  // "a" ascending, ties broken by "b" descending.
  const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
                                       SortKey("b", SortOrder::Descending)};

  // The table has no nulls, so every null placement yields the same order.
  for (const auto placement : AllNullPlacements()) {
    AssertSortIndices(table, SortOptions(sort_keys, placement), "[0, 6, 1, 4, 7, 3, 2, 5]");
  }
}
1616
1617 // For random table tests.
1618 using RandomParam = std::tuple<std::string, int, double>;
1619
1620 class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
1621 // Compares two records in a column
1622 class ColumnComparator : public TypeVisitor {
1623 public:
ColumnComparator(SortOrder order,NullPlacement null_placement)1624 ColumnComparator(SortOrder order, NullPlacement null_placement)
1625 : order_(order), null_placement_(null_placement) {}
1626
operator ()(const Array & left,const Array & right,uint64_t lhs,uint64_t rhs)1627 int operator()(const Array& left, const Array& right, uint64_t lhs, uint64_t rhs) {
1628 left_ = &left;
1629 right_ = &right;
1630 lhs_ = lhs;
1631 rhs_ = rhs;
1632 ARROW_CHECK_OK(left.type()->Accept(this));
1633 return compared_;
1634 }
1635
1636 #define VISIT(TYPE) \
1637 Status Visit(const TYPE##Type& type) override { \
1638 compared_ = CompareType<TYPE##Type>(); \
1639 return Status::OK(); \
1640 }
1641
1642 VISIT(Boolean)
VISIT(Int8)1643 VISIT(Int8)
1644 VISIT(Int16)
1645 VISIT(Int32)
1646 VISIT(Int64)
1647 VISIT(UInt8)
1648 VISIT(UInt16)
1649 VISIT(UInt32)
1650 VISIT(UInt64)
1651 VISIT(Float)
1652 VISIT(Double)
1653 VISIT(String)
1654 VISIT(LargeString)
1655 VISIT(Decimal128)
1656 VISIT(Decimal256)
1657
1658 #undef VISIT
1659
1660 template <typename Type>
1661 int CompareType() {
1662 using ArrayType = typename TypeTraits<Type>::ArrayType;
1663 ThreeWayComparator<ArrayType> three_way{order_, null_placement_};
1664 return three_way(checked_cast<const ArrayType&>(*left_),
1665 checked_cast<const ArrayType&>(*right_), lhs_, rhs_);
1666 }
1667
1668 const SortOrder order_;
1669 const NullPlacement null_placement_;
1670 const Array* left_;
1671 const Array* right_;
1672 uint64_t lhs_;
1673 uint64_t rhs_;
1674 int compared_;
1675 };
1676
1677 // Compares two records in the same table.
1678 class Comparator {
1679 public:
Comparator(const Table & table,const SortOptions & options)1680 Comparator(const Table& table, const SortOptions& options) : options_(options) {
1681 for (const auto& sort_key : options_.sort_keys) {
1682 sort_columns_.emplace_back(table.GetColumnByName(sort_key.name).get(),
1683 sort_key.order);
1684 }
1685 }
1686
1687 // Return true if the left record is less or equals to the right record,
1688 // false otherwise.
operator ()(uint64_t lhs,uint64_t rhs)1689 bool operator()(uint64_t lhs, uint64_t rhs) {
1690 for (const auto& pair : sort_columns_) {
1691 ColumnComparator comparator(pair.second, options_.null_placement);
1692 const auto& chunked_array = *pair.first;
1693 int64_t lhs_index = 0, rhs_index = 0;
1694 const Array* lhs_array = FindTargetArray(chunked_array, lhs, &lhs_index);
1695 const Array* rhs_array = FindTargetArray(chunked_array, rhs, &rhs_index);
1696 int compared = comparator(*lhs_array, *rhs_array, lhs_index, rhs_index);
1697 if (compared != 0) {
1698 return compared < 0;
1699 }
1700 }
1701 return lhs < rhs;
1702 }
1703
1704 // Find the target chunk and index in the target chunk from an
1705 // index in chunked array.
FindTargetArray(const ChunkedArray & chunked_array,int64_t i,int64_t * chunk_index)1706 const Array* FindTargetArray(const ChunkedArray& chunked_array, int64_t i,
1707 int64_t* chunk_index) {
1708 int64_t offset = 0;
1709 for (const auto& chunk : chunked_array.chunks()) {
1710 if (i < offset + chunk->length()) {
1711 *chunk_index = i - offset;
1712 return chunk.get();
1713 }
1714 offset += chunk->length();
1715 }
1716 return nullptr;
1717 }
1718
1719 const SortOptions& options_;
1720 std::vector<std::pair<const ChunkedArray*, SortOrder>> sort_columns_;
1721 };
1722
1723 public:
1724 // Validates the sorted indices are really sorted.
Validate(const Table & table,const SortOptions & options,UInt64Array & offsets)1725 void Validate(const Table& table, const SortOptions& options, UInt64Array& offsets) {
1726 ValidateOutput(offsets);
1727 Comparator comparator{table, options};
1728 for (int i = 1; i < table.num_rows(); i++) {
1729 uint64_t lhs = offsets.Value(i - 1);
1730 uint64_t rhs = offsets.Value(i);
1731 if (!comparator(lhs, rhs)) {
1732 std::stringstream ss;
1733 ss << "Rows not ordered at consecutive sort indices:";
1734 ss << "\nFirst row (index = " << lhs << "): ";
1735 PrintRow(table, lhs, &ss);
1736 ss << "\nSecond row (index = " << rhs << "): ";
1737 PrintRow(table, rhs, &ss);
1738 FAIL() << ss.str();
1739 }
1740 }
1741 }
1742
PrintRow(const Table & table,uint64_t index,std::ostream * os)1743 void PrintRow(const Table& table, uint64_t index, std::ostream* os) {
1744 *os << "{";
1745 const auto& columns = table.columns();
1746 for (size_t i = 0; i < columns.size(); ++i) {
1747 if (i != 0) {
1748 *os << ", ";
1749 }
1750 ASSERT_OK_AND_ASSIGN(auto scal, columns[i]->GetScalar(index));
1751 *os << scal->ToString();
1752 }
1753 *os << "}";
1754 }
1755 };
1756
// Sorts randomly generated tables (with several chunkings) and a record batch,
// validating the resulting indices with TestTableSortIndicesRandom::Validate.
TEST_P(TestTableSortIndicesRandom, Sort) {
  const auto first_sort_key_name = std::get<0>(GetParam());
  const auto n_sort_keys = std::get<1>(GetParam());
  const auto null_probability = std::get<2>(GetParam());
  const auto nan_probability = (1.0 - null_probability) / 4;
  const auto seed = 0x61549225;

  ARROW_SCOPED_TRACE("n_sort_keys = ", n_sort_keys);
  ARROW_SCOPED_TRACE("null_probability = ", null_probability);

  ::arrow::random::RandomArrayGenerator rng(seed);

  // Of these, "uint8", "boolean" and "string" should have many duplicates
  const FieldVector fields = {
      {field("uint8", uint8())},
      {field("int16", int16())},
      {field("int32", int32())},
      {field("uint64", uint64())},
      {field("float", float32())},
      {field("boolean", boolean())},
      {field("string", utf8())},
      {field("large_string", large_utf8())},
      {field("decimal128", decimal128(25, 3))},
      {field("decimal256", decimal256(42, 6))},
  };
  const auto schema = ::arrow::schema(fields);
  const int64_t length = 80;

  using ArrayFactory = std::function<std::shared_ptr<Array>(int64_t length)>;

  // One generator per field above, in the same order. Some columns are always
  // generated without nulls, whatever the null probability parameter.
  std::vector<ArrayFactory> column_factories{
      [&](int64_t length) { return rng.UInt8(length, 0, 10, null_probability); },
      [&](int64_t length) {
        return rng.Int16(length, -1000, 12000, /*null_probability=*/0.0);
      },
      [&](int64_t length) {
        return rng.Int32(length, -123456789, 987654321, null_probability);
      },
      [&](int64_t length) {
        return rng.UInt64(length, 1, 1234567890123456789ULL, /*null_probability=*/0.0);
      },
      [&](int64_t length) {
        return rng.Float32(length, -1.0f, 1.0f, null_probability, nan_probability);
      },
      [&](int64_t length) {
        return rng.Boolean(length, /*true_probability=*/0.3, null_probability);
      },
      [&](int64_t length) {
        if (length > 0) {
          return rng.StringWithRepeats(length, /*unique=*/1 + length / 10,
                                       /*min_length=*/5,
                                       /*max_length=*/15, null_probability);
        } else {
          return *MakeArrayOfNull(utf8(), 0);
        }
      },
      [&](int64_t length) {
        return rng.LargeString(length, /*min_length=*/5, /*max_length=*/15,
                               /*null_probability=*/0.0);
      },
      [&](int64_t length) {
        return rng.Decimal128(fields[8]->type(), length, null_probability);
      },
      [&](int64_t length) {
        return rng.Decimal256(fields[9]->type(), length, /*null_probability=*/0.0);
      },
  };

  // Generate random sort keys, making sure no column is included twice
  std::default_random_engine engine(seed);
  std::uniform_int_distribution<> distribution(0);

  auto generate_order = [&]() {
    return (distribution(engine) & 1) ? SortOrder::Ascending : SortOrder::Descending;
  };

  std::vector<SortKey> sort_keys;
  sort_keys.reserve(fields.size());
  for (const auto& field : fields) {
    if (field->name() != first_sort_key_name) {
      sort_keys.emplace_back(field->name(), generate_order());
    }
  }
  std::shuffle(sort_keys.begin(), sort_keys.end(), engine);
  sort_keys.emplace(sort_keys.begin(), first_sort_key_name, generate_order());
  // Guard the truncation below: an out-of-range `n_sort_keys` would form an
  // iterator outside [begin(), end()], which is undefined behavior.
  ASSERT_GE(n_sort_keys, 1);
  ASSERT_LE(static_cast<size_t>(n_sort_keys), sort_keys.size());
  sort_keys.erase(sort_keys.begin() + n_sort_keys, sort_keys.end());
  ASSERT_EQ(sort_keys.size(), static_cast<size_t>(n_sort_keys));

  std::stringstream ss;
  for (const auto& sort_key : sort_keys) {
    ss << sort_key.name << (sort_key.order == SortOrder::Ascending ? " ASC" : " DESC");
    ss << ", ";
  }
  ARROW_SCOPED_TRACE("sort_keys = ", ss.str());

  SortOptions options(sort_keys);

  // Test with different, heterogeneous table chunkings
  for (const int64_t max_num_chunks : {1, 3, 15}) {
    ARROW_SCOPED_TRACE("Table sorting: max chunks per column = ", max_num_chunks);
    std::uniform_int_distribution<int64_t> num_chunk_dist(1 + max_num_chunks / 2,
                                                          max_num_chunks);
    ChunkedArrayVector columns;
    columns.reserve(fields.size());

    // Chunk each column independently, and make sure they consist of
    // physically non-contiguous chunks.
    for (const auto& factory : column_factories) {
      const int64_t num_chunks = num_chunk_dist(engine);
      ArrayVector chunks(num_chunks);
      // Chunk boundaries: num_chunks + 1 offsets from 0 up to `length`.
      const auto offsets =
          checked_pointer_cast<Int32Array>(rng.Offsets(num_chunks + 1, 0, length));
      for (int64_t i = 0; i < num_chunks; ++i) {
        const auto chunk_len = offsets->Value(i + 1) - offsets->Value(i);
        chunks[i] = factory(chunk_len);
      }
      columns.push_back(std::make_shared<ChunkedArray>(std::move(chunks)));
      ASSERT_EQ(columns.back()->length(), length);
    }

    auto table = Table::Make(schema, std::move(columns));
    for (auto null_placement : AllNullPlacements()) {
      ARROW_SCOPED_TRACE("null_placement = ", null_placement);
      options.null_placement = null_placement;
      ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(Datum(*table), options));
      Validate(*table, options, *checked_pointer_cast<UInt64Array>(sort_indices));
    }
  }

  // Also validate RecordBatch sorting
  ARROW_SCOPED_TRACE("Record batch sorting");
  ArrayVector columns;
  columns.reserve(fields.size());
  for (const auto& factory : column_factories) {
    columns.push_back(factory(length));
  }
  auto batch = RecordBatch::Make(schema, length, std::move(columns));
  ASSERT_OK(batch->ValidateFull());
  ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(schema, {batch}));

  for (auto null_placement : AllNullPlacements()) {
    ARROW_SCOPED_TRACE("null_placement = ", null_placement);
    options.null_placement = null_placement;
    ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(Datum(batch), options));
    Validate(*table, options, *checked_pointer_cast<UInt64Array>(sort_indices));
  }
}
1904
// Some first keys will have duplicates, others not: in the Sort test, the
// "uint8", "boolean" and "string" columns are generated with many repeated values.
static const auto first_sort_keys = testing::Values("uint8", "int16", "uint64", "float",
                                                    "boolean", "string", "decimal128");

// Different numbers of sort keys may trigger different algorithms.
// Each value should lie in [1, number of columns generated by the Sort test] (10),
// since that test keeps only the first `n_sort_keys` of its shuffled keys.
static const auto num_sort_keys = testing::Values(1, 3, 7, 9);

// Test parameters are (first sort key name, number of sort keys, null probability).
// No nulls in any column.
INSTANTIATE_TEST_SUITE_P(NoNull, TestTableSortIndicesRandom,
                         testing::Combine(first_sort_keys, num_sort_keys,
                                          testing::Values(0.0)));

// Some nulls, in the columns whose generators use the null probability (the Sort
// test always generates a few columns with a fixed null probability of 0.0).
INSTANTIATE_TEST_SUITE_P(SomeNulls, TestTableSortIndicesRandom,
                         testing::Combine(first_sort_keys, num_sort_keys,
                                          testing::Values(0.1, 0.5)));

// Null probability 1.0: every column generated with the null probability is
// entirely null; the columns generated with a fixed 0.0 probability are not.
INSTANTIATE_TEST_SUITE_P(AllNull, TestTableSortIndicesRandom,
                         testing::Combine(first_sort_keys, num_sort_keys,
                                          testing::Values(1.0)));
1923
1924 } // namespace compute
1925 } // namespace arrow
1926