1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <tuple>
22 #include <utility>
23
24 #include <gtest/gtest.h>
25
26 #include "arrow/array.h"
27 #include "arrow/buffer.h"
28 #include "arrow/io/memory.h"
29 #include "arrow/ipc/feather.h"
30 #include "arrow/ipc/test_common.h"
31 #include "arrow/record_batch.h"
32 #include "arrow/status.h"
33 #include "arrow/table.h"
34 #include "arrow/testing/gtest_util.h"
35 #include "arrow/type.h"
36 #include "arrow/util/checked_cast.h"
37 #include "arrow/util/compression.h"
38
39 namespace arrow {
40
41 using internal::checked_cast;
42
43 namespace ipc {
44 namespace feather {
45
46 struct TestParam {
TestParamarrow::ipc::feather::TestParam47 TestParam(int arg_version,
48 Compression::type arg_compression = Compression::UNCOMPRESSED)
49 : version(arg_version), compression(arg_compression) {}
50
51 int version;
52 Compression::type compression;
53 };
54
PrintTo(const TestParam & p,std::ostream * os)55 void PrintTo(const TestParam& p, std::ostream* os) {
56 *os << "{version = " << p.version
57 << ", compression = " << ::arrow::util::Codec::GetCodecAsString(p.compression)
58 << "}";
59 }
60
61 class TestFeatherBase {
62 public:
SetUp()63 void SetUp() { Initialize(); }
64
Initialize()65 void Initialize() { ASSERT_OK_AND_ASSIGN(stream_, io::BufferOutputStream::Create()); }
66
67 virtual WriteProperties GetProperties() = 0;
68
DoWrite(const Table & table)69 void DoWrite(const Table& table) {
70 Initialize();
71 ASSERT_OK(WriteTable(table, stream_.get(), GetProperties()));
72 ASSERT_OK_AND_ASSIGN(output_, stream_->Finish());
73 auto buffer = std::make_shared<io::BufferReader>(output_);
74 ASSERT_OK_AND_ASSIGN(reader_, Reader::Open(buffer));
75 }
76
CheckSlice(std::shared_ptr<RecordBatch> batch,int start,int size)77 void CheckSlice(std::shared_ptr<RecordBatch> batch, int start, int size) {
78 batch = batch->Slice(start, size);
79 ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch}));
80
81 DoWrite(*table);
82 std::shared_ptr<Table> result;
83 ASSERT_OK(reader_->Read(&result));
84 ASSERT_OK(result->ValidateFull());
85 if (table->num_rows() > 0) {
86 AssertTablesEqual(*table, *result);
87 } else {
88 ASSERT_EQ(0, result->num_rows());
89 ASSERT_TRUE(result->schema()->Equals(*table->schema()));
90 }
91 }
92
CheckSlices(std::shared_ptr<RecordBatch> batch)93 void CheckSlices(std::shared_ptr<RecordBatch> batch) {
94 std::vector<int> starts = {0, 1, 300, 301, 302, 303, 304, 305, 306, 307};
95 std::vector<int> sizes = {0, 1, 7, 8, 30, 32, 100};
96 for (auto start : starts) {
97 for (auto size : sizes) {
98 CheckSlice(batch, start, size);
99 }
100 }
101 }
102
CheckRoundtrip(std::shared_ptr<RecordBatch> batch)103 void CheckRoundtrip(std::shared_ptr<RecordBatch> batch) {
104 std::vector<std::shared_ptr<RecordBatch>> batches = {batch};
105 ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(batches));
106
107 DoWrite(*table);
108
109 std::shared_ptr<Table> read_table;
110 ASSERT_OK(reader_->Read(&read_table));
111 ASSERT_OK(read_table->ValidateFull());
112 AssertTablesEqual(*table, *read_table);
113 }
114
115 protected:
116 std::shared_ptr<io::BufferOutputStream> stream_;
117 std::shared_ptr<Reader> reader_;
118 std::shared_ptr<Buffer> output_;
119 };
120
121 class TestFeather : public ::testing::TestWithParam<TestParam>, public TestFeatherBase {
122 public:
SetUp()123 void SetUp() { TestFeatherBase::SetUp(); }
124
GetProperties()125 WriteProperties GetProperties() {
126 auto param = GetParam();
127
128 auto props = WriteProperties::Defaults();
129 props.version = param.version;
130
131 // Don't fail if the build doesn't have LZ4_FRAME or ZSTD enabled
132 if (util::Codec::IsAvailable(param.compression)) {
133 props.compression = param.compression;
134 } else {
135 props.compression = Compression::UNCOMPRESSED;
136 }
137 return props;
138 }
139 };
140
141 class TestFeatherRoundTrip : public ::testing::TestWithParam<ipc::test::MakeRecordBatch*>,
142 public TestFeatherBase {
143 public:
SetUp()144 void SetUp() { TestFeatherBase::SetUp(); }
145
GetProperties()146 WriteProperties GetProperties() {
147 auto props = WriteProperties::Defaults();
148 props.version = kFeatherV2Version;
149
150 // Don't fail if the build doesn't have LZ4_FRAME or ZSTD enabled
151 if (!util::Codec::IsAvailable(props.compression)) {
152 props.compression = Compression::UNCOMPRESSED;
153 }
154 return props;
155 }
156 };
157
// The default codec depends on the build: LZ4_FRAME when LZ4 support is
// compiled in, otherwise no compression.
TEST(TestFeatherWriteProperties, Defaults) {
  const WriteProperties props = WriteProperties::Defaults();
#ifndef ARROW_WITH_LZ4
  ASSERT_EQ(Compression::UNCOMPRESSED, props.compression);
#else
  ASSERT_EQ(Compression::LZ4_FRAME, props.compression);
#endif
}
167
// Writes the whole integer batch, then projects column "f4" (index 4) both by
// position and by name; each projection must equal that single column.
TEST_P(TestFeather, ReadIndicesOrNames) {
  std::shared_ptr<RecordBatch> batch1;
  ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch1));

  ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch1}));
  DoWrite(*table);

  // MakeIntRecordBatch puts its int32 column at index 4, named "f4"
  auto expected = Table::Make(schema({field("f4", int32())}), {batch1->column(4)});

  // Projection by column index
  const std::vector<int> indices = {4};
  std::shared_ptr<Table> result1;
  ASSERT_OK(reader_->Read(indices, &result1));
  AssertTablesEqual(*expected, *result1);

  // Projection by column name
  const std::vector<std::string> names = {"f4"};
  std::shared_ptr<Table> result2;
  ASSERT_OK(reader_->Read(names, &result2));
  AssertTablesEqual(*expected, *result2);
}
189
// A table with no columns and zero rows must round-trip unchanged.
TEST_P(TestFeather, EmptyTable) {
  const std::vector<std::shared_ptr<ChunkedArray>> no_columns;
  const auto table = Table::Make(schema({}), no_columns, 0);

  DoWrite(*table);

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));
  AssertTablesEqual(*table, *result);
}
200
// A column-less table still carries a row count (1000 here), and that count
// must survive the round trip.
TEST_P(TestFeather, SetNumRows) {
  const std::vector<std::shared_ptr<ChunkedArray>> no_columns;
  const auto table = Table::Make(schema({}), no_columns, 1000);
  DoWrite(*table);

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));
  ASSERT_EQ(1000, result->num_rows());
}
209
// The integer batch from MakeIntRecordBatch must round-trip unchanged.
TEST_P(TestFeather, PrimitiveIntRoundTrip) {
  std::shared_ptr<RecordBatch> int_batch;
  ASSERT_OK(ipc::test::MakeIntRecordBatch(&int_batch));
  CheckRoundtrip(std::move(int_batch));
}
215
// The floating-point batch from MakeFloat3264Batch must round-trip unchanged.
TEST_P(TestFeather, PrimitiveFloatRoundTrip) {
  std::shared_ptr<RecordBatch> float_batch;
  ASSERT_OK(ipc::test::MakeFloat3264Batch(&float_batch));
  CheckRoundtrip(std::move(float_batch));
}
221
// The dictionary batch from MakeDictionaryFlat must round-trip unchanged.
TEST_P(TestFeather, CategoryRoundtrip) {
  std::shared_ptr<RecordBatch> dict_batch;
  ASSERT_OK(ipc::test::MakeDictionaryFlat(&dict_batch));
  CheckRoundtrip(std::move(dict_batch));
}
227
// Round-trips a 4-column batch: date32, time32[ms], timestamp[ns], and
// timestamp[s] with a time zone. All columns share one validity pattern.
// The time32 and timestamp columns are built directly from the buffers of
// plain int32/int64 arrays.
TEST_P(TestFeather, TimeTypes) {
  std::vector<bool> is_valid = {true, true, true, false, true, true, true};
  // Derive the batch length from the validity vector instead of repeating it
  // as a literal, so the two cannot drift apart.
  const auto length = static_cast<int64_t>(is_valid.size());

  auto f0 = field("f0", date32());
  auto f1 = field("f1", time32(TimeUnit::MILLI));
  auto f2 = field("f2", timestamp(TimeUnit::NANO));
  // "US/Los_Angeles" is not an IANA tz database name; use the canonical zone.
  auto f3 = field("f3", timestamp(TimeUnit::SECOND, "America/Los_Angeles"));
  auto schema = ::arrow::schema({f0, f1, f2, f3});

  std::vector<int64_t> values64_vec = {0, 1, 2, 3, 4, 5, 6};
  std::shared_ptr<Array> values64;
  ArrayFromVector<Int64Type, int64_t>(is_valid, values64_vec, &values64);

  std::vector<int32_t> values32_vec = {10, 11, 12, 13, 14, 15, 16};
  std::shared_ptr<Array> values32;
  ArrayFromVector<Int32Type, int32_t>(is_valid, values32_vec, &values32);

  std::vector<int32_t> date_values_vec = {20, 21, 22, 23, 24, 25, 26};
  std::shared_ptr<Array> date_array;
  ArrayFromVector<Date32Type, int32_t>(is_valid, date_values_vec, &date_array);

  const auto& prim_values64 = checked_cast<const PrimitiveArray&>(*values64);
  BufferVector buffers64 = {prim_values64.null_bitmap(), prim_values64.values()};

  const auto& prim_values32 = checked_cast<const PrimitiveArray&>(*values32);
  BufferVector buffers32 = {prim_values32.null_bitmap(), prim_values32.values()};

  std::vector<std::shared_ptr<ArrayData>> arrays;
  arrays.reserve(schema->num_fields());

  // Push date32 ArrayData
  arrays.push_back(date_array->data());

  // Create time32 ArrayData from the int32 buffers
  arrays.emplace_back(ArrayData::Make(schema->field(1)->type(), values32->length(),
                                      BufferVector(buffers32), values32->null_count(),
                                      0));

  // Create timestamp ArrayData (fields 2 and 3) from the int64 buffers
  for (int i = 2; i < schema->num_fields(); ++i) {
    arrays.emplace_back(ArrayData::Make(schema->field(i)->type(), values64->length(),
                                        BufferVector(buffers64), values64->null_count(),
                                        0));
  }

  auto batch = RecordBatch::Make(schema, length, std::move(arrays));
  CheckRoundtrip(batch);
}
273
// The string-type batch from MakeStringTypesRecordBatch must round-trip
// unchanged.
TEST_P(TestFeather, VLenPrimitiveRoundTrip) {
  std::shared_ptr<RecordBatch> string_batch;
  ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&string_batch));
  CheckRoundtrip(std::move(string_batch));
}
279
// Round-trips the batch from MakeNullRecordBatch. Under Feather V1, each
// column is expected to come back under its original name as an all-null
// utf8 array. Under V2, the table must be reproduced exactly.
TEST_P(TestFeather, PrimitiveNullRoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(ipc::test::MakeNullRecordBatch(&batch));

  ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch}));

  // Don't read from reader_ if the write/open step failed.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));

  if (GetParam().version == kFeatherV1Version) {
    // Guard the per-column indexing below against a short schema or result.
    ASSERT_EQ(batch->num_columns(), reader_->schema()->num_fields());
    ASSERT_EQ(batch->num_columns(), result->num_columns());
    for (int i = 0; i < batch->num_columns(); ++i) {
      ASSERT_EQ(batch->column_name(i), reader_->schema()->field(i)->name());
      ASSERT_OK_AND_ASSIGN(auto expected, MakeArrayOfNull(utf8(), batch->num_rows()));
      AssertArraysEqual(*expected, *result->column(i)->chunk(0));
    }
  } else {
    AssertTablesEqual(*table, *result);
  }
}
302
// Slice round trips (see CheckSlices) over an integer batch built with
// MakeIntBatchSized(600, ...).
TEST_P(TestFeather, SliceIntRoundTrip) {
  std::shared_ptr<RecordBatch> int_batch;
  ASSERT_OK(ipc::test::MakeIntBatchSized(600, &int_batch));
  CheckSlices(std::move(int_batch));
}
308
// Slice round trips (see CheckSlices) over a floating-point batch built with
// MakeFloat3264BatchSized(600, ...).
TEST_P(TestFeather, SliceFloatRoundTrip) {
  std::shared_ptr<RecordBatch> float_batch;
  // Float16 is not supported by FeatherV1
  ASSERT_OK(ipc::test::MakeFloat3264BatchSized(600, &float_batch));
  CheckSlices(std::move(float_batch));
}
315
// Slice round trips (see CheckSlices) over the string-type batch, with nulls.
TEST_P(TestFeather, SliceStringsRoundTrip) {
  std::shared_ptr<RecordBatch> string_batch;
  ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&string_batch, /*with_nulls=*/true));
  CheckSlices(std::move(string_batch));
}
321
// Slice round trips (see CheckSlices) over a boolean batch built with
// MakeBooleanBatchSized(600, ...).
TEST_P(TestFeather, SliceBooleanRoundTrip) {
  std::shared_ptr<RecordBatch> bool_batch;
  ASSERT_OK(ipc::test::MakeBooleanBatchSized(600, &bool_batch));
  CheckSlices(std::move(bool_batch));
}
327
// Runs every TestFeather case against:
//   - Feather V1
//   - uncompressed Feather V2
//   - Feather V2 with LZ4_FRAME
//   - Feather V2 with ZSTD
// Codecs missing from the build fall back to UNCOMPRESSED in
// TestFeather::GetProperties. gtest names instantiations by parameter index,
// so keep this order stable and append new values at the end.
INSTANTIATE_TEST_SUITE_P(
    FeatherTests, TestFeather,
    ::testing::Values(TestParam(kFeatherV1Version), TestParam(kFeatherV2Version),
                      TestParam(kFeatherV2Version, Compression::LZ4_FRAME),
                      TestParam(kFeatherV2Version, Compression::ZSTD)));
333
namespace {

// RecordBatch factories from ipc::test. TestFeatherRoundTrip round-trips each
// one through Feather V2, and each entry becomes one instantiation of
// TestFeatherRoundTrip.RoundTrip. Instantiations are named by index, so append
// new cases at the end to keep existing test names stable.
const std::vector<test::MakeRecordBatch*> kBatchCases = {
    &ipc::test::MakeIntRecordBatch,
    &ipc::test::MakeListRecordBatch,
    &ipc::test::MakeFixedSizeListRecordBatch,
    &ipc::test::MakeNonNullRecordBatch,
    &ipc::test::MakeDeeplyNestedList,
    &ipc::test::MakeStringTypesRecordBatchWithNulls,
    &ipc::test::MakeStruct,
    &ipc::test::MakeUnion,
    &ipc::test::MakeDictionary,
    &ipc::test::MakeNestedDictionary,
    &ipc::test::MakeMap,
    &ipc::test::MakeMapOfDictionary,
    &ipc::test::MakeDates,
    &ipc::test::MakeTimestamps,
    &ipc::test::MakeTimes,
    &ipc::test::MakeFWBinary,
    &ipc::test::MakeNull,
    &ipc::test::MakeDecimal,
    &ipc::test::MakeBooleanBatch,
    &ipc::test::MakeFloatBatch,
    &ipc::test::MakeIntervals};

}  // namespace
360
// Builds a batch with the factory under test, then checks it round-trips
// unchanged.
TEST_P(TestFeatherRoundTrip, RoundTrip) {
  ipc::test::MakeRecordBatch* make_batch = GetParam();

  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(make_batch(&batch));

  CheckRoundtrip(std::move(batch));
}
367
// One TestFeatherRoundTrip.RoundTrip instantiation per factory in kBatchCases.
INSTANTIATE_TEST_SUITE_P(FeatherRoundTripTests, TestFeatherRoundTrip,
                         ::testing::ValuesIn(kBatchCases));
370
371 } // namespace feather
372 } // namespace ipc
373 } // namespace arrow
374