1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <algorithm>
19 #include <array>
20 #include <cstdint>
21 #include <cstring>
22 #include <iterator>
23 #include <limits>
24 #include <memory>
25 #include <numeric>
26 #include <sstream>
27 #include <string>
28 #include <type_traits>
29 #include <vector>
30
31 #include <gtest/gtest.h>
32
33 #include "arrow/array.h"
34 #include "arrow/buffer.h"
35 #include "arrow/buffer_builder.h"
36 #include "arrow/extension_type.h"
37 #include "arrow/io/memory.h"
38 #include "arrow/ipc/reader.h"
39 #include "arrow/ipc/writer.h"
40 #include "arrow/record_batch.h"
41 #include "arrow/status.h"
42 #include "arrow/testing/extension_type.h"
43 #include "arrow/testing/gtest_common.h"
44 #include "arrow/testing/util.h"
45 #include "arrow/type.h"
46 #include "arrow/util/key_value_metadata.h"
47 #include "arrow/util/logging.h"
48
49 namespace arrow {
50
51 class Parametric1Array : public ExtensionArray {
52 public:
53 using ExtensionArray::ExtensionArray;
54 };
55
56 class Parametric2Array : public ExtensionArray {
57 public:
58 using ExtensionArray::ExtensionArray;
59 };
60
61 // A parametric type where the extension_name() is always the same
62 class Parametric1Type : public ExtensionType {
63 public:
Parametric1Type(int32_t parameter)64 explicit Parametric1Type(int32_t parameter)
65 : ExtensionType(int32()), parameter_(parameter) {}
66
parameter() const67 int32_t parameter() const { return parameter_; }
68
extension_name() const69 std::string extension_name() const override { return "parametric-type-1"; }
70
ExtensionEquals(const ExtensionType & other) const71 bool ExtensionEquals(const ExtensionType& other) const override {
72 const auto& other_ext = static_cast<const ExtensionType&>(other);
73 if (other_ext.extension_name() != this->extension_name()) {
74 return false;
75 }
76 return this->parameter() == static_cast<const Parametric1Type&>(other).parameter();
77 }
78
MakeArray(std::shared_ptr<ArrayData> data) const79 std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
80 return std::make_shared<Parametric1Array>(data);
81 }
82
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const83 Result<std::shared_ptr<DataType>> Deserialize(
84 std::shared_ptr<DataType> storage_type,
85 const std::string& serialized) const override {
86 DCHECK_EQ(4, serialized.size());
87 const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
88 DCHECK(storage_type->Equals(int32()));
89 return std::make_shared<Parametric1Type>(parameter);
90 }
91
Serialize() const92 std::string Serialize() const override {
93 std::string result(" ");
94 memcpy(&result[0], ¶meter_, sizeof(int32_t));
95 return result;
96 }
97
98 private:
99 int32_t parameter_;
100 };
101
102 // A parametric type where the extension_name() is different for each
103 // parameter, and must be separately registered
104 class Parametric2Type : public ExtensionType {
105 public:
Parametric2Type(int32_t parameter)106 explicit Parametric2Type(int32_t parameter)
107 : ExtensionType(int32()), parameter_(parameter) {}
108
parameter() const109 int32_t parameter() const { return parameter_; }
110
extension_name() const111 std::string extension_name() const override {
112 std::stringstream ss;
113 ss << "parametric-type-2<param=" << parameter_ << ">";
114 return ss.str();
115 }
116
ExtensionEquals(const ExtensionType & other) const117 bool ExtensionEquals(const ExtensionType& other) const override {
118 const auto& other_ext = static_cast<const ExtensionType&>(other);
119 if (other_ext.extension_name() != this->extension_name()) {
120 return false;
121 }
122 return this->parameter() == static_cast<const Parametric2Type&>(other).parameter();
123 }
124
MakeArray(std::shared_ptr<ArrayData> data) const125 std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
126 return std::make_shared<Parametric2Array>(data);
127 }
128
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const129 Result<std::shared_ptr<DataType>> Deserialize(
130 std::shared_ptr<DataType> storage_type,
131 const std::string& serialized) const override {
132 DCHECK_EQ(4, serialized.size());
133 const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
134 DCHECK(storage_type->Equals(int32()));
135 return std::make_shared<Parametric2Type>(parameter);
136 }
137
Serialize() const138 std::string Serialize() const override {
139 std::string result(" ");
140 memcpy(&result[0], ¶meter_, sizeof(int32_t));
141 return result;
142 }
143
144 private:
145 int32_t parameter_;
146 };
147
148 // An extension type with a non-primitive storage type
149 class ExtStructArray : public ExtensionArray {
150 public:
151 using ExtensionArray::ExtensionArray;
152 };
153
154 class ExtStructType : public ExtensionType {
155 public:
ExtStructType()156 ExtStructType()
157 : ExtensionType(
158 struct_({::arrow::field("a", int64()), ::arrow::field("b", float64())})) {}
159
extension_name() const160 std::string extension_name() const override { return "ext-struct-type"; }
161
ExtensionEquals(const ExtensionType & other) const162 bool ExtensionEquals(const ExtensionType& other) const override {
163 const auto& other_ext = static_cast<const ExtensionType&>(other);
164 if (other_ext.extension_name() != this->extension_name()) {
165 return false;
166 }
167 return true;
168 }
169
MakeArray(std::shared_ptr<ArrayData> data) const170 std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
171 return std::make_shared<ExtStructArray>(data);
172 }
173
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const174 Result<std::shared_ptr<DataType>> Deserialize(
175 std::shared_ptr<DataType> storage_type,
176 const std::string& serialized) const override {
177 if (serialized != "ext-struct-type-unique-code") {
178 return Status::Invalid("Type identifier did not match");
179 }
180 return std::make_shared<ExtStructType>();
181 }
182
Serialize() const183 std::string Serialize() const override { return "ext-struct-type-unique-code"; }
184 };
185
186 class TestExtensionType : public ::testing::Test {
187 public:
SetUp()188 void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared<UuidType>())); }
189
TearDown()190 void TearDown() {
191 if (GetExtensionType("uuid")) {
192 ASSERT_OK(UnregisterExtensionType("uuid"));
193 }
194 }
195 };
196
// Exercises registry lookup and the Serialize()/Deserialize() roundtrip of
// the uuid extension type.
TEST_F(TestExtensionType, ExtensionTypeTest) {
  // A name that was never registered resolves to a null type...
  auto missing = GetExtensionType("uuid-unknown");
  ASSERT_EQ(missing, nullptr);

  // ...whereas "uuid" was registered by the fixture's SetUp().
  auto found = GetExtensionType("uuid");
  ASSERT_NE(found, nullptr);

  auto uuid_type = uuid();
  ASSERT_EQ(uuid_type->id(), Type::EXTENSION);

  const auto& as_extension = static_cast<const ExtensionType&>(*uuid_type);
  const std::string payload = as_extension.Serialize();

  // The deserialized type equals the extension type, not its bare storage type.
  ASSERT_OK_AND_ASSIGN(auto roundtripped,
                       as_extension.Deserialize(fixed_size_binary(16), payload));
  ASSERT_TRUE(roundtripped->Equals(*uuid_type));
  ASSERT_FALSE(roundtripped->Equals(*fixed_size_binary(16)));
}
215
216 auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch,
__anon50113e460102(const std::shared_ptr<RecordBatch>& batch, std::shared_ptr<RecordBatch>* out) 217 std::shared_ptr<RecordBatch>* out) {
218 ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
219 ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
220 out_stream.get()));
221
222 ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());
223
224 io::BufferReader reader(complete_ipc_stream);
225 std::shared_ptr<RecordBatchReader> batch_reader;
226 ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
227 ASSERT_OK(batch_reader->ReadNext(out));
228 };
229
// A registered extension type must survive an IPC roundtrip both as a
// top-level column and as the value type of a list column.
TEST_F(TestExtensionType, IpcRoundtrip) {
  auto uuid_values = ExampleUuid();
  std::shared_ptr<RecordBatch> roundtripped;

  // Case 1: the extension array is the column itself.
  auto flat_batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {uuid_values});
  RoundtripBatch(flat_batch, &roundtripped);
  CompareBatch(*flat_batch, *roundtripped, false /* compare_metadata */);

  // Case 2: wrap the extension array in a ListArray and ensure it also makes it
  auto list_offsets = ArrayFromJSON(int32(), "[0, 0, 2, 4]");
  ASSERT_OK_AND_ASSIGN(auto uuid_lists,
                       ListArray::FromArrays(*list_offsets, *uuid_values));
  auto nested_batch =
      RecordBatch::Make(schema({field("f0", list(uuid()))}), 3, {uuid_lists});
  RoundtripBatch(nested_batch, &roundtripped);
  CompareBatch(*nested_batch, *roundtripped, false /* compare_metadata */);
}
245
// Reading a stream whose extension type is no longer registered must yield a
// column of the plain storage type, with the extension name and serialized
// metadata kept as field metadata.
TEST_F(TestExtensionType, UnrecognizedExtension) {
  auto uuid_values = ExampleUuid();
  auto storage_values = static_cast<const ExtensionArray&>(*uuid_values).storage();
  auto ext_batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {uuid_values});

  // Step 1: write the full IPC stream (schema included) while "uuid" is
  // still registered.
  ASSERT_OK_AND_ASSIGN(auto sink, io::BufferOutputStream::Create());
  ASSERT_OK(ipc::WriteRecordBatchStream({ext_batch}, ipc::IpcWriteOptions::Defaults(),
                                        sink.get()));
  ASSERT_OK_AND_ASSIGN(auto ipc_bytes, sink->Finish());

  // Step 2: drop the registration so the reader cannot resolve the type.
  ASSERT_OK(UnregisterExtensionType("uuid"));

  // Step 3: build the batch we expect back: storage type plus the extension
  // information as key/value field metadata.
  auto expected_metadata =
      key_value_metadata({{"ARROW:extension:name", "uuid"},
                          {"ARROW:extension:metadata", "uuid-serialized"}});
  auto storage_field = field("f0", fixed_size_binary(16), true, expected_metadata);
  auto expected_batch = RecordBatch::Make(schema({storage_field}), 4, {storage_values});

  // Step 4: read the stream back and compare.
  io::BufferReader source(ipc_bytes);
  std::shared_ptr<RecordBatchReader> stream_reader;
  ASSERT_OK_AND_ASSIGN(stream_reader, ipc::RecordBatchStreamReader::Open(&source));
  std::shared_ptr<RecordBatch> actual_batch;
  ASSERT_OK(stream_reader->ReadNext(&actual_batch));
  CompareBatch(*expected_batch, *actual_batch);
}
274
ExampleParametric(std::shared_ptr<DataType> type,const std::string & json_data)275 std::shared_ptr<Array> ExampleParametric(std::shared_ptr<DataType> type,
276 const std::string& json_data) {
277 auto arr = ArrayFromJSON(int32(), json_data);
278 auto ext_data = arr->data()->Copy();
279 ext_data->type = type;
280 return MakeArray(ext_data);
281 }
282
// Roundtrips a batch mixing both kinds of parametric extension types through
// IPC: Parametric1Type shares one registered name across parameters, while
// each Parametric2Type instance has its own name and registration.
TEST_F(TestExtensionType, ParametricTypes) {
  auto p1_type = std::make_shared<Parametric1Type>(6);
  auto p1 = ExampleParametric(p1_type, "[null, 1, 2, 3]");

  auto p2_type = std::make_shared<Parametric1Type>(12);
  auto p2 = ExampleParametric(p2_type, "[2, null, 3, 4]");

  auto p3_type = std::make_shared<Parametric2Type>(2);
  auto p3 = ExampleParametric(p3_type, "[5, 6, 7, 8]");

  auto p4_type = std::make_shared<Parametric2Type>(3);
  auto p4 = ExampleParametric(p4_type, "[5, 6, 7, 9]");

  // One registration (with an arbitrary parameter) covers every
  // Parametric1Type; Parametric2Type needs one registration per parameter.
  ASSERT_OK(RegisterExtensionType(std::make_shared<Parametric1Type>(-1)));
  ASSERT_OK(RegisterExtensionType(p3_type));
  ASSERT_OK(RegisterExtensionType(p4_type));

  auto batch = RecordBatch::Make(schema({field("f0", p1_type), field("f1", p2_type),
                                         field("f2", p3_type), field("f3", p4_type)}),
                                 4, {p1, p2, p3, p4});

  std::shared_ptr<RecordBatch> read_batch;
  RoundtripBatch(batch, &read_batch);

  // Undo the registrations made above so the global registry is left as this
  // test found it; otherwise a repeated run fails on the duplicate names.
  // The registrations are only needed for the read above, so this is done
  // before the comparison.
  ASSERT_OK(UnregisterExtensionType(p1_type->extension_name()));
  ASSERT_OK(UnregisterExtensionType(p3_type->extension_name()));
  ASSERT_OK(UnregisterExtensionType(p4_type->extension_name()));

  // RoundtripBatch returns early on failure, leaving read_batch null.
  ASSERT_NE(read_batch, nullptr);
  CompareBatch(*batch, *read_batch, false /* compare_metadata */);
}
308
// Equality of Parametric1Type follows its parameter, and the type reports an
// empty fingerprint.
TEST_F(TestExtensionType, ParametricEquals) {
  auto six = std::make_shared<Parametric1Type>(6);
  auto also_six = std::make_shared<Parametric1Type>(6);
  auto three = std::make_shared<Parametric1Type>(3);

  // Same parameter: equal. Different parameter: not equal.
  ASSERT_TRUE(six->Equals(also_six));
  ASSERT_FALSE(six->Equals(three));

  ASSERT_EQ(six->fingerprint(), "");
}
319
ExampleStruct()320 std::shared_ptr<Array> ExampleStruct() {
321 auto ext_type = std::make_shared<ExtStructType>();
322 auto storage_type = ext_type->storage_type();
323 auto arr = ArrayFromJSON(storage_type, "[[1, 0.1], [2, 0.2]]");
324
325 auto ext_data = arr->data()->Copy();
326 ext_data->type = ext_type;
327 return MakeArray(ext_data);
328 }
329
// ValidateFull() must succeed on extension arrays with fixed-size-binary,
// int32 and struct storage.
TEST_F(TestExtensionType, ValidateExtensionArray) {
  auto parametric_type = std::make_shared<Parametric1Type>(6);

  const std::vector<std::shared_ptr<Array>> samples = {
      ExampleUuid(),
      ExampleParametric(parametric_type, "[null, 1, 2, 3]"),
      ExampleStruct(),
  };

  for (const auto& sample : samples) {
    ASSERT_OK(sample->ValidateFull());
  }
}
340
341 } // namespace arrow
342