1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <gtest/gtest.h>
19
20 #include "arrow/chunked_array.h"
21 #include "arrow/compute/api.h"
22 #include "arrow/compute/kernels/test_util.h"
23 #include "arrow/result.h"
24 #include "arrow/testing/gtest_util.h"
25 #include "arrow/testing/matchers.h"
26 #include "arrow/util/key_value_metadata.h"
27
28 namespace arrow {
29 namespace compute {
30
GetOffsetType(const DataType & type)31 static std::shared_ptr<DataType> GetOffsetType(const DataType& type) {
32 return type.id() == Type::LIST ? int32() : int64();
33 }
34
// list_value_length: checks the per-slot lengths reported for list,
// large_list and fixed_size_list inputs, including null and empty lists.
TEST(TestScalarNested, ListValueLength) {
  // Variable-size lists; the expected output type is chosen by GetOffsetType.
  const char* variable_input = "[[0, null, 1], null, [2, 3], []]";
  const char* variable_lengths = "[3, null, 2, 0]";
  for (const auto& list_type : {list(int32()), large_list(int32())}) {
    CheckScalarUnary("list_value_length", list_type, variable_input,
                     GetOffsetType(*list_type), variable_lengths);
  }

  // Fixed-size lists of width 3; lengths are expected as int32.
  const char* fixed_input = "[[0, null, 1], null, [2, 3, 4], [1, 2, null]]";
  const char* fixed_lengths = "[3, null, 3, 3]";
  CheckScalarUnary("list_value_length", fixed_size_list(int32(), 3), fixed_input,
                   int32(), fixed_lengths);
}
45
// list_element at index 1 over list/large_list inputs of every numeric value
// type and every integer index type. The sample contains a null element and
// a null list, both of which are expected to yield null.
TEST(TestScalarNested, ListElementNonFixedListWithNulls) {
  const char* lists_json =
      "[[7, 5, 81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]";
  for (const auto& value_type : NumericTypes()) {
    // The expected outputs depend only on the value type, so build them once.
    auto expected = ArrayFromJSON(value_type, "[5, null, 12, 9, null]");
    auto expected_null = ArrayFromJSON(value_type, "[null]");
    for (const auto& list_type : {list(value_type), large_list(value_type)}) {
      auto input = ArrayFromJSON(list_type, lists_json);
      auto null_input = ArrayFromJSON(list_type, "[null]");
      for (const auto& index_type : IntTypes()) {
        auto index = ScalarFromJSON(index_type, "1");
        CheckScalar("list_element", {input, index}, expected);
        CheckScalar("list_element", {null_input, index}, expected_null);
      }
    }
  }
}
62
// list_element at index 0 over fixed_size_list(…, 3) inputs of every numeric
// value type, using every integer index type.
TEST(TestScalarNested, ListElementFixedList) {
  const char* lists_json = "[[7, 5, 81], [6, 4, 8], [3, 12, 2], [1, 43, 87]]";
  for (const auto& value_type : NumericTypes()) {
    auto input = ArrayFromJSON(fixed_size_list(value_type, 3), lists_json);
    // First element of each list; independent of the index type.
    auto expected = ArrayFromJSON(value_type, "[7, 6, 3, 1]");
    for (const auto& index_type : IntTypes()) {
      auto index = ScalarFromJSON(index_type, "0");
      CheckScalar("list_element", {input, index}, expected);
    }
  }
}
74
// list_element must fail with StatusCode::Invalid for unusable indices, for
// both array and scalar list inputs.
TEST(TestScalarNested, ListElementInvalid) {
  auto two_element_lists = ArrayFromJSON(list(float32()), "[[0.1, 1.1], [0.2, 1.2]]");
  auto two_element_scalar = ScalarFromJSON(list(float32()), "[0.1, 0.2]");

  // invalid index: null
  auto null_index = ScalarFromJSON(int32(), "null");
  EXPECT_THAT(CallFunction("list_element", {two_element_lists, null_index}),
              Raises(StatusCode::Invalid));
  EXPECT_THAT(CallFunction("list_element", {two_element_scalar, null_index}),
              Raises(StatusCode::Invalid));

  // invalid index: < 0
  auto negative_index = ScalarFromJSON(int32(), "-1");
  EXPECT_THAT(CallFunction("list_element", {two_element_lists, negative_index}),
              Raises(StatusCode::Invalid));
  EXPECT_THAT(CallFunction("list_element", {two_element_scalar, negative_index}),
              Raises(StatusCode::Invalid));

  // invalid index: >= list.length (every list above has exactly 2 elements)
  auto past_end_index = ScalarFromJSON(int32(), "2");
  EXPECT_THAT(CallFunction("list_element", {two_element_lists, past_end_index}),
              Raises(StatusCode::Invalid));
  EXPECT_THAT(CallFunction("list_element", {two_element_scalar, past_end_index}),
              Raises(StatusCode::Invalid));

  // invalid input: index 0 is requested, but the array contains an empty list
  // and the scalar is itself an empty list
  auto lists_with_empty = ArrayFromJSON(list(float32()), "[[41, 6, 93], [], [2]]");
  auto empty_list_scalar = ScalarFromJSON(list(float32()), "[]");
  auto first_index = ScalarFromJSON(int32(), "0");
  EXPECT_THAT(CallFunction("list_element", {lists_with_empty, first_index}),
              Raises(StatusCode::Invalid));
  EXPECT_THAT(CallFunction("list_element", {empty_list_scalar, first_index}),
              Raises(StatusCode::Invalid));
}
109
110 struct {
operator ()arrow::compute::__anon306313010108111 Result<Datum> operator()(std::vector<Datum> args) {
112 return CallFunction("make_struct", args);
113 }
114
115 template <typename... Options>
operator ()arrow::compute::__anon306313010108116 Result<Datum> operator()(std::vector<Datum> args, std::vector<std::string> field_names,
117 Options... options) {
118 MakeStructOptions opts{field_names, options...};
119 return CallFunction("make_struct", args, &opts);
120 }
121 } MakeStruct;
122
// make_struct over scalar inputs: explicit names, default names, the empty
// struct, and a names/values count mismatch.
TEST(MakeStruct, Scalar) {
  auto int_value = MakeScalar(1);
  auto float_value = MakeScalar(2.5);
  auto string_value = MakeScalar("yo");

  // Explicitly supplied names are used for the struct fields
  const std::vector<std::string> names{"i", "f", "s"};
  const Datum expected_named(
      *StructScalar::Make({int_value, float_value, string_value}, names));
  EXPECT_THAT(MakeStruct({int_value, float_value, string_value}, names),
              ResultWith(expected_named));

  // Names default to field_index
  const Datum expected_indexed(
      *StructScalar::Make({int_value, float_value, string_value}, {"0", "1", "2"}));
  EXPECT_THAT(MakeStruct({int_value, float_value, string_value}),
              ResultWith(expected_indexed));

  // No field names or input values is fine
  const Datum expected_empty(*StructScalar::Make({}, {}));
  EXPECT_THAT(MakeStruct({}), ResultWith(expected_empty));

  // Three field names but one input value
  EXPECT_THAT(MakeStruct({string_value}, names), Raises(StatusCode::Invalid));
}
141
// make_struct over array inputs: plain arrays, an array combined with a
// scalar, and arrays of different lengths.
TEST(MakeStruct, Array) {
  const std::vector<std::string> field_names{"i", "s"};

  auto ints = ArrayFromJSON(int32(), "[42, 13, 7]");
  auto strs = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");

  // Both successful calls below are expected to produce this struct array.
  const Datum expected(*StructArray::Make({ints, strs}, field_names));

  EXPECT_THAT(MakeStruct({ints, strs}, field_names), ResultWith(expected));

  // Scalars are broadcast to the length of the arrays
  EXPECT_THAT(MakeStruct({ints, MakeScalar("aa")}, field_names), ResultWith(expected));

  // Array length mismatch
  EXPECT_THAT(MakeStruct({ints->Slice(1), strs}, field_names),
              Raises(StatusCode::Invalid));
}
158
// make_struct with per-field nullability and metadata: both must show up on
// the resulting struct type, and nulls in a non-nullable field are rejected.
TEST(MakeStruct, NullableMetadataPassedThru) {
  const std::vector<std::string> field_names{"i", "s"};
  const std::vector<bool> nullability{true, false};
  const std::vector<std::shared_ptr<const KeyValueMetadata>> metadata = {
      key_value_metadata({"a", "b"}, {"ALPHA", "BRAVO"}), nullptr};

  auto ints = ArrayFromJSON(int32(), "[42, 13, 7]");
  auto strs = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");

  ASSERT_OK_AND_ASSIGN(auto projected,
                       MakeStruct({ints, strs}, field_names, nullability, metadata));

  // Field "i" carries the first metadata entry; field "s" has none.
  const StructType expected_type({
      field("i", int32(), /*nullable=*/true, metadata[0]),
      field("s", utf8(), /*nullable=*/false, nullptr),
  });
  AssertTypeEqual(*projected.type(), expected_type);

  // error: projecting an array containing nulls with nullable=false
  auto strs_with_null = ArrayFromJSON(utf8(), R"(["aa", null, "aa"])");
  EXPECT_THAT(MakeStruct({ints, strs_with_null}, field_names, nullability, metadata),
              Raises(StatusCode::Invalid));
}
181
// make_struct over chunked inputs whose chunk boundaries line up
// (chunk lengths 3, 0 and 2 in both columns).
TEST(MakeStruct, ChunkedArray) {
  const std::vector<std::string> field_names{"i", "s"};

  // First chunk: three rows
  auto ints_a = ArrayFromJSON(int32(), "[42, 13, 7]");
  auto strs_a = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
  ASSERT_OK_AND_ASSIGN(auto structs_a, StructArray::Make({ints_a, strs_a}, field_names));

  // Second chunk: empty
  auto ints_b = ArrayFromJSON(int32(), "[]");
  auto strs_b = ArrayFromJSON(utf8(), "[]");
  ASSERT_OK_AND_ASSIGN(auto structs_b, StructArray::Make({ints_b, strs_b}, field_names));

  // Third chunk: two rows
  auto ints_c = ArrayFromJSON(int32(), "[32, 0]");
  auto strs_c = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
  ASSERT_OK_AND_ASSIGN(auto structs_c, StructArray::Make({ints_c, strs_c}, field_names));

  ASSERT_OK_AND_ASSIGN(auto ints, ChunkedArray::Make({ints_a, ints_b, ints_c}));
  ASSERT_OK_AND_ASSIGN(auto strs, ChunkedArray::Make({strs_a, strs_b, strs_c}));
  ASSERT_OK_AND_ASSIGN(Datum expected,
                       ChunkedArray::Make({structs_a, structs_b, structs_c}));

  ASSERT_OK_AND_EQ(expected, MakeStruct({ints, strs}, field_names));

  // Scalars are broadcast to the length of the arrays
  ASSERT_OK_AND_EQ(expected, MakeStruct({ints, MakeScalar("aa")}, field_names));

  // Array length mismatch
  ASSERT_RAISES(Invalid, MakeStruct({ints->Slice(1), strs}, field_names));
}
210
// make_struct over chunked inputs whose chunk boundaries differ: the int
// column has chunk lengths 3, 0, 2 and the string column 2, 1, 0, 2.
TEST(MakeStruct, ChunkedArrayDifferentChunking) {
  const std::vector<std::string> field_names{"i", "s"};

  auto ints_a = ArrayFromJSON(int32(), "[42, 13, 7]");
  auto ints_b = ArrayFromJSON(int32(), "[]");
  auto ints_c = ArrayFromJSON(int32(), "[32, 0]");
  ASSERT_OK_AND_ASSIGN(auto ints, ChunkedArray::Make({ints_a, ints_b, ints_c}));

  auto strs_a = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
  auto strs_b = ArrayFromJSON(utf8(), R"(["aa"])");
  auto strs_c = ArrayFromJSON(utf8(), R"([])");
  auto strs_d = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
  ASSERT_OK_AND_ASSIGN(auto strs, ChunkedArray::Make({strs_a, strs_b, strs_c, strs_d}));

  // Build the expected value from both columns rechunked to a common layout.
  const std::vector<ArrayVector> rechunked =
      ::arrow::internal::RechunkArraysConsistently({ints->chunks(), strs->chunks()});
  ASSERT_EQ(rechunked[0].size(), rechunked[1].size());

  const size_t num_chunks = rechunked[0].size();
  ArrayVector struct_chunks(num_chunks);
  for (size_t i = 0; i < num_chunks; ++i) {
    ASSERT_OK_AND_ASSIGN(
        struct_chunks[i],
        StructArray::Make({rechunked[0][i], rechunked[1][i]}, field_names));
  }

  ASSERT_OK_AND_ASSIGN(Datum expected, ChunkedArray::Make(struct_chunks));

  ASSERT_OK_AND_EQ(expected, MakeStruct({ints, strs}, field_names));

  // Scalars are broadcast to the length of the arrays
  ASSERT_OK_AND_EQ(expected, MakeStruct({ints, MakeScalar("aa")}, field_names));

  // Array length mismatch
  ASSERT_RAISES(Invalid, MakeStruct({ints->Slice(1), strs}, field_names));
}
247
248 } // namespace compute
249 } // namespace arrow
250