1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
23
24 #include <gtest/gtest.h>
25
26 #include "arrow/array.h"
27 #include "arrow/buffer.h"
28 #include "arrow/io/memory.h"
29 #include "arrow/ipc/feather.h"
30 #include "arrow/ipc/test_common.h"
31 #include "arrow/record_batch.h"
32 #include "arrow/status.h"
33 #include "arrow/table.h"
34 #include "arrow/testing/gtest_util.h"
35 #include "arrow/type.h"
36 #include "arrow/util/checked_cast.h"
37 #include "arrow/util/compression.h"
38
39 namespace arrow {
40
41 using internal::checked_cast;
42
43 namespace ipc {
44 namespace feather {
45
46 struct TestParam {
TestParamarrow::ipc::feather::TestParam47 TestParam(int arg_version,
48 Compression::type arg_compression = Compression::UNCOMPRESSED)
49 : version(arg_version), compression(arg_compression) {}
50
51 int version;
52 Compression::type compression;
53 };
54
55 class TestFeather : public ::testing::TestWithParam<TestParam> {
56 public:
SetUp()57 void SetUp() { Initialize(); }
58
Initialize()59 void Initialize() { ASSERT_OK_AND_ASSIGN(stream_, io::BufferOutputStream::Create()); }
60
GetProperties()61 WriteProperties GetProperties() {
62 auto param = GetParam();
63
64 auto props = WriteProperties::Defaults();
65 props.version = param.version;
66
67 // Don't fail if the build doesn't have LZ4_FRAME or ZSTD enabled
68 if (util::Codec::IsAvailable(param.compression)) {
69 props.compression = param.compression;
70 } else {
71 props.compression = Compression::UNCOMPRESSED;
72 }
73 return props;
74 }
75
DoWrite(const Table & table)76 void DoWrite(const Table& table) {
77 Initialize();
78 ASSERT_OK(WriteTable(table, stream_.get(), GetProperties()));
79 ASSERT_OK_AND_ASSIGN(output_, stream_->Finish());
80 auto buffer = std::make_shared<io::BufferReader>(output_);
81 ASSERT_OK_AND_ASSIGN(reader_, Reader::Open(buffer));
82 }
83
CheckSlice(std::shared_ptr<RecordBatch> batch,int start,int size)84 void CheckSlice(std::shared_ptr<RecordBatch> batch, int start, int size) {
85 batch = batch->Slice(start, size);
86 ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch}));
87
88 DoWrite(*table);
89 std::shared_ptr<Table> result;
90 ASSERT_OK(reader_->Read(&result));
91 if (table->num_rows() > 0) {
92 AssertTablesEqual(*table, *result);
93 } else {
94 ASSERT_EQ(0, result->num_rows());
95 ASSERT_TRUE(result->schema()->Equals(*table->schema()));
96 }
97 }
98
CheckSlices(std::shared_ptr<RecordBatch> batch)99 void CheckSlices(std::shared_ptr<RecordBatch> batch) {
100 std::vector<int> starts = {0, 1, 300, 301, 302, 303, 304, 305, 306, 307};
101 std::vector<int> sizes = {0, 1, 7, 8, 30, 32, 100};
102 for (auto start : starts) {
103 for (auto size : sizes) {
104 CheckSlice(batch, start, size);
105 }
106 }
107 }
108
CheckRoundtrip(std::shared_ptr<RecordBatch> batch)109 void CheckRoundtrip(std::shared_ptr<RecordBatch> batch) {
110 std::vector<std::shared_ptr<RecordBatch>> batches = {batch};
111 ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(batches));
112
113 DoWrite(*table);
114
115 std::shared_ptr<Table> read_table;
116 ASSERT_OK(reader_->Read(&read_table));
117 AssertTablesEqual(*table, *read_table);
118 }
119
120 protected:
121 std::shared_ptr<io::BufferOutputStream> stream_;
122 std::shared_ptr<Reader> reader_;
123 std::shared_ptr<Buffer> output_;
124 };
125
TEST(TestFeatherWriteProperties, Defaults) {
  // The default codec is LZ4_FRAME when the build has LZ4, otherwise none.
#ifdef ARROW_WITH_LZ4
  const Compression::type expected_codec = Compression::LZ4_FRAME;
#else
  const Compression::type expected_codec = Compression::UNCOMPRESSED;
#endif
  const auto defaults = WriteProperties::Defaults();
  ASSERT_EQ(expected_codec, defaults.compression);
}
135
TEST_P(TestFeather, ReadIndicesOrNames) {
  std::shared_ptr<RecordBatch> batch1;
  ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch1));

  ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch1}));

  // If the write or reopen failed, reader_ may be null; stop before using it.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  // Selecting column 1 by index or by its name "f1" must give the same table.
  auto expected = Table::Make(schema({field("f1", int32())}), {batch1->column(1)});

  std::shared_ptr<Table> result1;
  std::shared_ptr<Table> result2;

  std::vector<int> indices = {1};
  ASSERT_OK(reader_->Read(indices, &result1));
  AssertTablesEqual(*expected, *result1);

  std::vector<std::string> names = {"f1"};
  ASSERT_OK(reader_->Read(names, &result2));
  AssertTablesEqual(*expected, *result2);
}
156
TEST_P(TestFeather, EmptyTable) {
  // A table with no columns and zero rows.
  std::vector<std::shared_ptr<ChunkedArray>> columns;
  auto table = Table::Make(schema({}), columns, 0);

  // If the write or reopen failed, reader_ may be null; stop before using it.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));
  AssertTablesEqual(*table, *result);
}
167
TEST_P(TestFeather, SetNumRows) {
  // A zero-column table that still declares 1000 rows.
  std::vector<std::shared_ptr<ChunkedArray>> columns;
  auto table = Table::Make(schema({}), columns, 1000);

  // If the write or reopen failed, reader_ may be null; stop before using it.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));
  ASSERT_EQ(0, result->num_columns());
  ASSERT_EQ(1000, result->num_rows());
}
176
TEST_P(TestFeather, PrimitiveRoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch));
  // Same write / read / compare sequence the other round-trip tests use.
  CheckRoundtrip(batch);
}
189
TEST_P(TestFeather, CategoryRoundtrip) {
  // Dictionary-encoded columns built by ipc::test::MakeDictionaryFlat.
  std::shared_ptr<RecordBatch> dictionary_batch;
  ASSERT_OK(ipc::test::MakeDictionaryFlat(&dictionary_batch));
  CheckRoundtrip(dictionary_batch);
}
195
TEST_P(TestFeather, TimeTypes) {
  // All four columns share one validity pattern: only slot 3 is null.
  std::vector<bool> is_valid = {true, true, true, false, true, true, true};

  auto schema = ::arrow::schema(
      {field("f0", date32()), field("f1", time32(TimeUnit::MILLI)),
       field("f2", timestamp(TimeUnit::NANO)),
       field("f3", timestamp(TimeUnit::SECOND, "US/Los_Angeles"))});

  const std::vector<int32_t> date_values = {20, 21, 22, 23, 24, 25, 26};
  const std::vector<int32_t> int32_values = {10, 11, 12, 13, 14, 15, 16};
  const std::vector<int64_t> int64_values = {0, 1, 2, 3, 4, 5, 6};

  std::shared_ptr<Array> date_array;
  std::shared_ptr<Array> int32_array;
  std::shared_ptr<Array> int64_array;
  ArrayFromVector<Date32Type, int32_t>(is_valid, date_values, &date_array);
  ArrayFromVector<Int32Type, int32_t>(is_valid, int32_values, &int32_array);
  ArrayFromVector<Int64Type, int64_t>(is_valid, int64_values, &int64_array);

  // Wraps the validity and value buffers of primitive `source` in new ArrayData
  // of `type`; used below only with types whose storage width matches `source`.
  auto relabel = [](const std::shared_ptr<DataType>& type,
                    const std::shared_ptr<Array>& source) {
    const auto& primitive = checked_cast<const PrimitiveArray&>(*source);
    return ArrayData::Make(type, source->length(),
                           BufferVector{primitive.null_bitmap(), primitive.values()},
                           source->null_count(), /*offset=*/0);
  };

  std::vector<std::shared_ptr<ArrayData>> columns;
  columns.push_back(date_array->data());                              // date32
  columns.push_back(relabel(schema->field(1)->type(), int32_array));  // time32
  for (int col = 2; col < schema->num_fields(); ++col) {              // timestamps
    columns.push_back(relabel(schema->field(col)->type(), int64_array));
  }

  CheckRoundtrip(RecordBatch::Make(schema, 7, std::move(columns)));
}
241
TEST_P(TestFeather, VLenPrimitiveRoundTrip) {
  // Variable-length columns built by ipc::test::MakeStringTypesRecordBatch.
  std::shared_ptr<RecordBatch> string_batch;
  ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&string_batch));
  CheckRoundtrip(string_batch);
}
247
TEST_P(TestFeather, PrimitiveNullRoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(ipc::test::MakeNullRecordBatch(&batch));

  ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch}));

  // If the write or reopen failed, reader_ may be null; stop before using it.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));

  if (GetParam().version == kFeatherV1Version) {
    // For V1 each column is expected back as a StringArray with the source
    // column's length, validity bitmap and null count.
    // Check the column count first so result->column(i) is never out of range.
    ASSERT_EQ(batch->num_columns(), result->num_columns());
    for (int i = 0; i < batch->num_columns(); ++i) {
      ASSERT_EQ(batch->column_name(i), reader_->schema()->field(i)->name());
      StringArray str_values(batch->column(i)->length(), nullptr, nullptr,
                             batch->column(i)->null_bitmap(),
                             batch->column(i)->null_count());
      AssertArraysEqual(str_values, *result->column(i)->chunk(0));
    }
  } else {
    AssertTablesEqual(*table, *result);
  }
}
272
TEST_P(TestFeather, SliceRoundTrip) {
  // Integer batch built by MakeIntBatchSized(600, ...), checked slice by slice.
  std::shared_ptr<RecordBatch> int_batch;
  ASSERT_OK(ipc::test::MakeIntBatchSized(600, &int_batch));
  CheckSlices(int_batch);
}
278
TEST_P(TestFeather, SliceStringsRoundTrip) {
  // String-type batch, requested with nulls, checked slice by slice.
  std::shared_ptr<RecordBatch> string_batch;
  ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&string_batch, /*with_nulls=*/true));
  CheckSlices(string_batch);
}
284
TEST_P(TestFeather, SliceBooleanRoundTrip) {
  // Boolean batch built by MakeBooleanBatchSized(600, ...), checked slice by slice.
  std::shared_ptr<RecordBatch> bool_batch;
  ASSERT_OK(ipc::test::MakeBooleanBatchSized(600, &bool_batch));
  CheckSlices(bool_batch);
}
290
291 INSTANTIATE_TEST_SUITE_P(
292 FeatherTests, TestFeather,
293 ::testing::Values(TestParam(kFeatherV1Version), TestParam(kFeatherV2Version),
294 TestParam(kFeatherV2Version, Compression::LZ4_FRAME),
295 TestParam(kFeatherV2Version, Compression::ZSTD)));
296
297 } // namespace feather
298 } // namespace ipc
299 } // namespace arrow
300