1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <tuple>
22 #include <utility>
23 
24 #include <gtest/gtest.h>
25 
26 #include "arrow/array.h"
27 #include "arrow/buffer.h"
28 #include "arrow/io/memory.h"
29 #include "arrow/ipc/feather.h"
30 #include "arrow/ipc/test_common.h"
31 #include "arrow/record_batch.h"
32 #include "arrow/status.h"
33 #include "arrow/table.h"
34 #include "arrow/testing/gtest_util.h"
35 #include "arrow/type.h"
36 #include "arrow/util/checked_cast.h"
37 #include "arrow/util/compression.h"
38 
39 namespace arrow {
40 
41 using internal::checked_cast;
42 
43 namespace ipc {
44 namespace feather {
45 
46 struct TestParam {
TestParamarrow::ipc::feather::TestParam47   TestParam(int arg_version,
48             Compression::type arg_compression = Compression::UNCOMPRESSED)
49       : version(arg_version), compression(arg_compression) {}
50 
51   int version;
52   Compression::type compression;
53 };
54 
PrintTo(const TestParam & p,std::ostream * os)55 void PrintTo(const TestParam& p, std::ostream* os) {
56   *os << "{version = " << p.version
57       << ", compression = " << ::arrow::util::Codec::GetCodecAsString(p.compression)
58       << "}";
59 }
60 
61 class TestFeatherBase {
62  public:
SetUp()63   void SetUp() { Initialize(); }
64 
Initialize()65   void Initialize() { ASSERT_OK_AND_ASSIGN(stream_, io::BufferOutputStream::Create()); }
66 
67   virtual WriteProperties GetProperties() = 0;
68 
DoWrite(const Table & table)69   void DoWrite(const Table& table) {
70     Initialize();
71     ASSERT_OK(WriteTable(table, stream_.get(), GetProperties()));
72     ASSERT_OK_AND_ASSIGN(output_, stream_->Finish());
73     auto buffer = std::make_shared<io::BufferReader>(output_);
74     ASSERT_OK_AND_ASSIGN(reader_, Reader::Open(buffer));
75   }
76 
CheckSlice(std::shared_ptr<RecordBatch> batch,int start,int size)77   void CheckSlice(std::shared_ptr<RecordBatch> batch, int start, int size) {
78     batch = batch->Slice(start, size);
79     ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch}));
80 
81     DoWrite(*table);
82     std::shared_ptr<Table> result;
83     ASSERT_OK(reader_->Read(&result));
84     ASSERT_OK(result->ValidateFull());
85     if (table->num_rows() > 0) {
86       AssertTablesEqual(*table, *result);
87     } else {
88       ASSERT_EQ(0, result->num_rows());
89       ASSERT_TRUE(result->schema()->Equals(*table->schema()));
90     }
91   }
92 
CheckSlices(std::shared_ptr<RecordBatch> batch)93   void CheckSlices(std::shared_ptr<RecordBatch> batch) {
94     std::vector<int> starts = {0, 1, 300, 301, 302, 303, 304, 305, 306, 307};
95     std::vector<int> sizes = {0, 1, 7, 8, 30, 32, 100};
96     for (auto start : starts) {
97       for (auto size : sizes) {
98         CheckSlice(batch, start, size);
99       }
100     }
101   }
102 
CheckRoundtrip(std::shared_ptr<RecordBatch> batch)103   void CheckRoundtrip(std::shared_ptr<RecordBatch> batch) {
104     std::vector<std::shared_ptr<RecordBatch>> batches = {batch};
105     ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(batches));
106 
107     DoWrite(*table);
108 
109     std::shared_ptr<Table> read_table;
110     ASSERT_OK(reader_->Read(&read_table));
111     ASSERT_OK(read_table->ValidateFull());
112     AssertTablesEqual(*table, *read_table);
113   }
114 
115  protected:
116   std::shared_ptr<io::BufferOutputStream> stream_;
117   std::shared_ptr<Reader> reader_;
118   std::shared_ptr<Buffer> output_;
119 };
120 
121 class TestFeather : public ::testing::TestWithParam<TestParam>, public TestFeatherBase {
122  public:
SetUp()123   void SetUp() { TestFeatherBase::SetUp(); }
124 
GetProperties()125   WriteProperties GetProperties() {
126     auto param = GetParam();
127 
128     auto props = WriteProperties::Defaults();
129     props.version = param.version;
130 
131     // Don't fail if the build doesn't have LZ4_FRAME or ZSTD enabled
132     if (util::Codec::IsAvailable(param.compression)) {
133       props.compression = param.compression;
134     } else {
135       props.compression = Compression::UNCOMPRESSED;
136     }
137     return props;
138   }
139 };
140 
141 class TestFeatherRoundTrip : public ::testing::TestWithParam<ipc::test::MakeRecordBatch*>,
142                              public TestFeatherBase {
143  public:
SetUp()144   void SetUp() { TestFeatherBase::SetUp(); }
145 
GetProperties()146   WriteProperties GetProperties() {
147     auto props = WriteProperties::Defaults();
148     props.version = kFeatherV2Version;
149 
150     // Don't fail if the build doesn't have LZ4_FRAME or ZSTD enabled
151     if (!util::Codec::IsAvailable(props.compression)) {
152       props.compression = Compression::UNCOMPRESSED;
153     }
154     return props;
155   }
156 };
157 
TEST(TestFeatherWriteProperties, Defaults) {
  // The default codec depends on whether the build was configured with LZ4.
#ifdef ARROW_WITH_LZ4
  constexpr Compression::type kExpectedCompression = Compression::LZ4_FRAME;
#else
  constexpr Compression::type kExpectedCompression = Compression::UNCOMPRESSED;
#endif

  const auto default_props = WriteProperties::Defaults();
  ASSERT_EQ(kExpectedCompression, default_props.compression);
}
167 
TEST_P(TestFeather, ReadIndicesOrNames) {
  std::shared_ptr<RecordBatch> batch1;
  ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch1));

  ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch1}));

  // ASSERT_* inside DoWrite only returns from DoWrite; stop here on failure so
  // reader_ is never dereferenced after a failed write/open.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  // int32 type is at the column f4 of the result of MakeIntRecordBatch
  auto expected = Table::Make(schema({field("f4", int32())}), {batch1->column(4)});

  // Projecting by column index...
  std::shared_ptr<Table> result_by_index;
  std::vector<int> indices = {4};
  ASSERT_OK(reader_->Read(indices, &result_by_index));
  AssertTablesEqual(*expected, *result_by_index);

  // ...and by column name must both yield just column f4.
  std::shared_ptr<Table> result_by_name;
  std::vector<std::string> names = {"f4"};
  ASSERT_OK(reader_->Read(names, &result_by_name));
  AssertTablesEqual(*expected, *result_by_name);
}
189 
TEST_P(TestFeather, EmptyTable) {
  // A table with no columns and no rows must round-trip unchanged.
  std::vector<std::shared_ptr<ChunkedArray>> columns;
  auto table = Table::Make(schema({}), columns, 0);

  // Abort before touching reader_ if the write/open failed.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));
  AssertTablesEqual(*table, *result);
}
200 
TEST_P(TestFeather, SetNumRows) {
  // A column-less table with an explicit row count must keep that row count.
  std::vector<std::shared_ptr<ChunkedArray>> columns;
  auto table = Table::Make(schema({}), columns, 1000);

  // Abort before touching reader_ if the write/open failed.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));
  ASSERT_EQ(1000, result->num_rows());
}
209 
TEST_P(TestFeather, PrimitiveIntRoundTrip) {
  // Integer columns from the shared IPC test helpers.
  std::shared_ptr<RecordBatch> int_batch;
  ASSERT_OK(ipc::test::MakeIntRecordBatch(&int_batch));

  CheckRoundtrip(int_batch);
}
215 
TEST_P(TestFeather, PrimitiveFloatRoundTrip) {
  // float32 / float64 columns from the shared IPC test helpers.
  std::shared_ptr<RecordBatch> float_batch;
  ASSERT_OK(ipc::test::MakeFloat3264Batch(&float_batch));

  CheckRoundtrip(float_batch);
}
221 
TEST_P(TestFeather, CategoryRoundtrip) {
  // Dictionary-encoded ("category") columns from MakeDictionaryFlat.
  std::shared_ptr<RecordBatch> dict_batch;
  ASSERT_OK(ipc::test::MakeDictionaryFlat(&dict_batch));

  CheckRoundtrip(dict_batch);
}
227 
TEST_P(TestFeather, TimeTypes) {
  // Seven rows; row 3 is null in every column.
  constexpr int64_t kNumRows = 7;
  const std::vector<bool> is_valid = {true, true, true, false, true, true, true};

  const auto time_schema =
      ::arrow::schema({field("f0", date32()), field("f1", time32(TimeUnit::MILLI)),
                       field("f2", timestamp(TimeUnit::NANO)),
                       field("f3", timestamp(TimeUnit::SECOND, "US/Los_Angeles"))});

  // Plain integer arrays whose buffers are reused for the temporal columns.
  const std::vector<int64_t> raw64 = {0, 1, 2, 3, 4, 5, 6};
  std::shared_ptr<Array> values64;
  ArrayFromVector<Int64Type, int64_t>(is_valid, raw64, &values64);

  const std::vector<int32_t> raw32 = {10, 11, 12, 13, 14, 15, 16};
  std::shared_ptr<Array> values32;
  ArrayFromVector<Int32Type, int32_t>(is_valid, raw32, &values32);

  const std::vector<int32_t> raw_dates = {20, 21, 22, 23, 24, 25, 26};
  std::shared_ptr<Array> date_array;
  ArrayFromVector<Date32Type, int32_t>(is_valid, raw_dates, &date_array);

  // Validity bitmap + value buffer of each integer array.
  const auto& int64_prim = checked_cast<const PrimitiveArray&>(*values64);
  const BufferVector buffers64 = {int64_prim.null_bitmap(), int64_prim.values()};

  const auto& int32_prim = checked_cast<const PrimitiveArray&>(*values32);
  const BufferVector buffers32 = {int32_prim.null_bitmap(), int32_prim.values()};

  std::vector<std::shared_ptr<ArrayData>> columns;

  // f0: the date32 array as built.
  columns.push_back(date_array->data());

  // f1: time32 data over the int32 buffers.
  columns.push_back(ArrayData::Make(time_schema->field(1)->type(), values32->length(),
                                    buffers32, values32->null_count(),
                                    /*offset=*/0));

  // f2, f3: timestamp data over the int64 buffers.
  for (int col = 2; col < time_schema->num_fields(); ++col) {
    columns.push_back(ArrayData::Make(time_schema->field(col)->type(),
                                      values64->length(), buffers64,
                                      values64->null_count(), /*offset=*/0));
  }

  auto batch = RecordBatch::Make(time_schema, kNumRows, std::move(columns));
  CheckRoundtrip(batch);
}
273 
TEST_P(TestFeather, VLenPrimitiveRoundTrip) {
  // Variable-length (string/binary) columns built with default options.
  std::shared_ptr<RecordBatch> string_batch;
  ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&string_batch));

  CheckRoundtrip(string_batch);
}
279 
TEST_P(TestFeather, PrimitiveNullRoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(ipc::test::MakeNullRecordBatch(&batch));

  ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch}));

  // Abort before touching reader_ if the write/open failed.
  ASSERT_NO_FATAL_FAILURE(DoWrite(*table));

  std::shared_ptr<Table> result;
  ASSERT_OK(reader_->Read(&result));

  if (GetParam().version == kFeatherV1Version) {
    // With V1 the columns are expected to come back as all-null utf8 arrays
    // under the original names (presumably V1 cannot store the null type).
    for (int i = 0; i < batch->num_columns(); ++i) {
      ASSERT_EQ(batch->column_name(i), reader_->schema()->field(i)->name());
      ASSERT_OK_AND_ASSIGN(auto expected, MakeArrayOfNull(utf8(), batch->num_rows()));
      AssertArraysEqual(*expected, *result->column(i)->chunk(0));
    }
  } else {
    AssertTablesEqual(*table, *result);
  }
}
302 
TEST_P(TestFeather, SliceIntRoundTrip) {
  // 600 rows, enough for every slice start/size used by CheckSlices.
  std::shared_ptr<RecordBatch> int_batch;
  ASSERT_OK(ipc::test::MakeIntBatchSized(600, &int_batch));

  CheckSlices(int_batch);
}
308 
TEST_P(TestFeather, SliceFloatRoundTrip) {
  // Only float32/float64: Float16 is not supported by FeatherV1.
  std::shared_ptr<RecordBatch> float_batch;
  ASSERT_OK(ipc::test::MakeFloat3264BatchSized(600, &float_batch));

  CheckSlices(float_batch);
}
315 
TEST_P(TestFeather, SliceStringsRoundTrip) {
  // String-typed columns that include nulls, sliced at many offsets.
  std::shared_ptr<RecordBatch> string_batch;
  const bool with_nulls = true;
  ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&string_batch, with_nulls));

  CheckSlices(string_batch);
}
321 
TEST_P(TestFeather, SliceBooleanRoundTrip) {
  // Boolean columns, whose values are bit-packed, sliced at many offsets.
  std::shared_ptr<RecordBatch> bool_batch;
  ASSERT_OK(ipc::test::MakeBooleanBatchSized(600, &bool_batch));

  CheckSlices(bool_batch);
}
327 
328 INSTANTIATE_TEST_SUITE_P(
329     FeatherTests, TestFeather,
330     ::testing::Values(TestParam(kFeatherV1Version), TestParam(kFeatherV2Version),
331                       TestParam(kFeatherV2Version, Compression::LZ4_FRAME),
332                       TestParam(kFeatherV2Version, Compression::ZSTD)));
333 
334 namespace {
335 
336 const std::vector<test::MakeRecordBatch*> kBatchCases = {
337     &ipc::test::MakeIntRecordBatch,
338     &ipc::test::MakeListRecordBatch,
339     &ipc::test::MakeFixedSizeListRecordBatch,
340     &ipc::test::MakeNonNullRecordBatch,
341     &ipc::test::MakeDeeplyNestedList,
342     &ipc::test::MakeStringTypesRecordBatchWithNulls,
343     &ipc::test::MakeStruct,
344     &ipc::test::MakeUnion,
345     &ipc::test::MakeDictionary,
346     &ipc::test::MakeNestedDictionary,
347     &ipc::test::MakeMap,
348     &ipc::test::MakeMapOfDictionary,
349     &ipc::test::MakeDates,
350     &ipc::test::MakeTimestamps,
351     &ipc::test::MakeTimes,
352     &ipc::test::MakeFWBinary,
353     &ipc::test::MakeNull,
354     &ipc::test::MakeDecimal,
355     &ipc::test::MakeBooleanBatch,
356     &ipc::test::MakeFloatBatch,
357     &ipc::test::MakeIntervals};
358 
359 }  // namespace
360 
TEST_P(TestFeatherRoundTrip, RoundTrip) {
  // The test parameter is a factory that fills in the batch to round-trip.
  ipc::test::MakeRecordBatch* const make_batch = GetParam();  // NOLINT clang-tidy gtest issue

  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(make_batch(&batch));

  CheckRoundtrip(batch);
}
367 
368 INSTANTIATE_TEST_SUITE_P(FeatherRoundTripTests, TestFeatherRoundTrip,
369                          ::testing::ValuesIn(kBatchCases));
370 
371 }  // namespace feather
372 }  // namespace ipc
373 }  // namespace arrow
374