1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include <algorithm>
19 #include <array>
20 #include <cstdint>
21 #include <cstring>
22 #include <iterator>
23 #include <limits>
24 #include <memory>
25 #include <numeric>
26 #include <sstream>
27 #include <string>
28 #include <type_traits>
29 #include <vector>
30 
31 #include <gtest/gtest.h>
32 
33 #include "arrow/array.h"
34 #include "arrow/buffer.h"
35 #include "arrow/buffer_builder.h"
36 #include "arrow/extension_type.h"
37 #include "arrow/io/memory.h"
38 #include "arrow/ipc/reader.h"
39 #include "arrow/ipc/writer.h"
40 #include "arrow/record_batch.h"
41 #include "arrow/status.h"
42 #include "arrow/testing/extension_type.h"
43 #include "arrow/testing/gtest_common.h"
44 #include "arrow/testing/util.h"
45 #include "arrow/type.h"
46 #include "arrow/util/key_value_metadata.h"
47 #include "arrow/util/logging.h"
48 
49 namespace arrow {
50 
51 class Parametric1Array : public ExtensionArray {
52  public:
53   using ExtensionArray::ExtensionArray;
54 };
55 
56 class Parametric2Array : public ExtensionArray {
57  public:
58   using ExtensionArray::ExtensionArray;
59 };
60 
61 // A parametric type where the extension_name() is always the same
62 class Parametric1Type : public ExtensionType {
63  public:
Parametric1Type(int32_t parameter)64   explicit Parametric1Type(int32_t parameter)
65       : ExtensionType(int32()), parameter_(parameter) {}
66 
parameter() const67   int32_t parameter() const { return parameter_; }
68 
extension_name() const69   std::string extension_name() const override { return "parametric-type-1"; }
70 
ExtensionEquals(const ExtensionType & other) const71   bool ExtensionEquals(const ExtensionType& other) const override {
72     const auto& other_ext = static_cast<const ExtensionType&>(other);
73     if (other_ext.extension_name() != this->extension_name()) {
74       return false;
75     }
76     return this->parameter() == static_cast<const Parametric1Type&>(other).parameter();
77   }
78 
MakeArray(std::shared_ptr<ArrayData> data) const79   std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
80     return std::make_shared<Parametric1Array>(data);
81   }
82 
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const83   Result<std::shared_ptr<DataType>> Deserialize(
84       std::shared_ptr<DataType> storage_type,
85       const std::string& serialized) const override {
86     DCHECK_EQ(4, serialized.size());
87     const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
88     DCHECK(storage_type->Equals(int32()));
89     return std::make_shared<Parametric1Type>(parameter);
90   }
91 
Serialize() const92   std::string Serialize() const override {
93     std::string result("    ");
94     memcpy(&result[0], &parameter_, sizeof(int32_t));
95     return result;
96   }
97 
98  private:
99   int32_t parameter_;
100 };
101 
102 // A parametric type where the extension_name() is different for each
103 // parameter, and must be separately registered
104 class Parametric2Type : public ExtensionType {
105  public:
Parametric2Type(int32_t parameter)106   explicit Parametric2Type(int32_t parameter)
107       : ExtensionType(int32()), parameter_(parameter) {}
108 
parameter() const109   int32_t parameter() const { return parameter_; }
110 
extension_name() const111   std::string extension_name() const override {
112     std::stringstream ss;
113     ss << "parametric-type-2<param=" << parameter_ << ">";
114     return ss.str();
115   }
116 
ExtensionEquals(const ExtensionType & other) const117   bool ExtensionEquals(const ExtensionType& other) const override {
118     const auto& other_ext = static_cast<const ExtensionType&>(other);
119     if (other_ext.extension_name() != this->extension_name()) {
120       return false;
121     }
122     return this->parameter() == static_cast<const Parametric2Type&>(other).parameter();
123   }
124 
MakeArray(std::shared_ptr<ArrayData> data) const125   std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
126     return std::make_shared<Parametric2Array>(data);
127   }
128 
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const129   Result<std::shared_ptr<DataType>> Deserialize(
130       std::shared_ptr<DataType> storage_type,
131       const std::string& serialized) const override {
132     DCHECK_EQ(4, serialized.size());
133     const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
134     DCHECK(storage_type->Equals(int32()));
135     return std::make_shared<Parametric2Type>(parameter);
136   }
137 
Serialize() const138   std::string Serialize() const override {
139     std::string result("    ");
140     memcpy(&result[0], &parameter_, sizeof(int32_t));
141     return result;
142   }
143 
144  private:
145   int32_t parameter_;
146 };
147 
148 // An extension type with a non-primitive storage type
149 class ExtStructArray : public ExtensionArray {
150  public:
151   using ExtensionArray::ExtensionArray;
152 };
153 
154 class ExtStructType : public ExtensionType {
155  public:
ExtStructType()156   ExtStructType()
157       : ExtensionType(
158             struct_({::arrow::field("a", int64()), ::arrow::field("b", float64())})) {}
159 
extension_name() const160   std::string extension_name() const override { return "ext-struct-type"; }
161 
ExtensionEquals(const ExtensionType & other) const162   bool ExtensionEquals(const ExtensionType& other) const override {
163     const auto& other_ext = static_cast<const ExtensionType&>(other);
164     if (other_ext.extension_name() != this->extension_name()) {
165       return false;
166     }
167     return true;
168   }
169 
MakeArray(std::shared_ptr<ArrayData> data) const170   std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
171     return std::make_shared<ExtStructArray>(data);
172   }
173 
Deserialize(std::shared_ptr<DataType> storage_type,const std::string & serialized) const174   Result<std::shared_ptr<DataType>> Deserialize(
175       std::shared_ptr<DataType> storage_type,
176       const std::string& serialized) const override {
177     if (serialized != "ext-struct-type-unique-code") {
178       return Status::Invalid("Type identifier did not match");
179     }
180     return std::make_shared<ExtStructType>();
181   }
182 
Serialize() const183   std::string Serialize() const override { return "ext-struct-type-unique-code"; }
184 };
185 
186 class TestExtensionType : public ::testing::Test {
187  public:
SetUp()188   void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared<UuidType>())); }
189 
TearDown()190   void TearDown() {
191     if (GetExtensionType("uuid")) {
192       ASSERT_OK(UnregisterExtensionType("uuid"));
193     }
194   }
195 };
196 
// Checks registry lookup for a known and an unknown name. Then checks that
// Serialize/Deserialize of the uuid type yields a type equal to the original
// and different from its bare fixed_size_binary(16) storage.
TEST_F(TestExtensionType, ExtensionTypeTest) {
  ASSERT_EQ(GetExtensionType("uuid-unknown"), nullptr);
  ASSERT_NE(GetExtensionType("uuid"), nullptr);

  auto uuid_type = uuid();
  ASSERT_EQ(uuid_type->id(), Type::EXTENSION);

  const auto& as_extension = static_cast<const ExtensionType&>(*uuid_type);
  const std::string metadata = as_extension.Serialize();

  ASSERT_OK_AND_ASSIGN(auto round_tripped,
                       as_extension.Deserialize(fixed_size_binary(16), metadata));
  ASSERT_TRUE(round_tripped->Equals(*uuid_type));
  ASSERT_FALSE(round_tripped->Equals(*fixed_size_binary(16)));
}
215 
216 auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch,
__anon50113e460102(const std::shared_ptr<RecordBatch>& batch, std::shared_ptr<RecordBatch>* out) 217                          std::shared_ptr<RecordBatch>* out) {
218   ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
219   ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
220                                         out_stream.get()));
221 
222   ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());
223 
224   io::BufferReader reader(complete_ipc_stream);
225   std::shared_ptr<RecordBatchReader> batch_reader;
226   ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
227   ASSERT_OK(batch_reader->ReadNext(out));
228 };
229 
// Round-trips a uuid extension column through IPC, first as a top-level column
// and then nested as the child of a list column.
TEST_F(TestExtensionType, IpcRoundtrip) {
  auto ext_arr = ExampleUuid();
  auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr});

  std::shared_ptr<RecordBatch> read_batch;
  // A failed ASSERT inside RoundtripBatch only returns from the helper; stop
  // the test here instead of dereferencing a null (or stale) read_batch.
  ASSERT_NO_FATAL_FAILURE(RoundtripBatch(batch, &read_batch));
  CompareBatch(*batch, *read_batch, false /* compare_metadata */);

  // Wrap the type in a ListArray and ensure it also survives the roundtrip
  auto offsets_arr = ArrayFromJSON(int32(), "[0, 0, 2, 4]");
  ASSERT_OK_AND_ASSIGN(auto list_arr, ListArray::FromArrays(*offsets_arr, *ext_arr));
  batch = RecordBatch::Make(schema({field("f0", list(uuid()))}), 3, {list_arr});
  ASSERT_NO_FATAL_FAILURE(RoundtripBatch(batch, &read_batch));
  CompareBatch(*batch, *read_batch, false /* compare_metadata */);
}
245 
// Writes a uuid batch to an IPC stream, unregisters "uuid", then reads the
// stream back. The column must come back as plain fixed_size_binary(16)
// storage, with the extension name and metadata kept as field metadata.
TEST_F(TestExtensionType, UnrecognizedExtension) {
  auto uuid_arr = ExampleUuid();
  auto original = RecordBatch::Make(schema({field("f0", uuid())}), 4, {uuid_arr});
  auto storage = static_cast<const ExtensionArray&>(*uuid_arr).storage();

  // Serialize the full stream (schema included) while "uuid" is still known.
  ASSERT_OK_AND_ASSIGN(auto sink, io::BufferOutputStream::Create());
  ASSERT_OK(ipc::WriteRecordBatchStream({original}, ipc::IpcWriteOptions::Defaults(),
                                        sink.get()));
  ASSERT_OK_AND_ASSIGN(auto ipc_buffer, sink->Finish());

  // Forget the type, then describe what the reader has to fall back to.
  ASSERT_OK(UnregisterExtensionType("uuid"));
  auto expected_field =
      field("f0", fixed_size_binary(16), true,
            key_value_metadata({{"ARROW:extension:name", "uuid"},
                                {"ARROW:extension:metadata", "uuid-serialized"}}));
  auto expected = RecordBatch::Make(schema({expected_field}), 4, {storage});

  io::BufferReader source(ipc_buffer);
  std::shared_ptr<RecordBatchReader> stream_reader;
  ASSERT_OK_AND_ASSIGN(stream_reader, ipc::RecordBatchStreamReader::Open(&source));
  std::shared_ptr<RecordBatch> actual;
  ASSERT_OK(stream_reader->ReadNext(&actual));
  CompareBatch(*expected, *actual);
}
274 
ExampleParametric(std::shared_ptr<DataType> type,const std::string & json_data)275 std::shared_ptr<Array> ExampleParametric(std::shared_ptr<DataType> type,
276                                          const std::string& json_data) {
277   auto arr = ArrayFromJSON(int32(), json_data);
278   auto ext_data = arr->data()->Copy();
279   ext_data->type = type;
280   return MakeArray(ext_data);
281 }
282 
// Round-trips columns of both parametric flavours through IPC:
// - two Parametric1Type instances that share one registered name;
// - two Parametric2Type instances, each registered under its own name.
TEST_F(TestExtensionType, ParametricTypes) {
  auto p1_type = std::make_shared<Parametric1Type>(6);
  auto p1 = ExampleParametric(p1_type, "[null, 1, 2, 3]");

  auto p2_type = std::make_shared<Parametric1Type>(12);
  auto p2 = ExampleParametric(p2_type, "[2, null, 3, 4]");

  auto p3_type = std::make_shared<Parametric2Type>(2);
  auto p3 = ExampleParametric(p3_type, "[5, 6, 7, 8]");

  auto p4_type = std::make_shared<Parametric2Type>(3);
  auto p4 = ExampleParametric(p4_type, "[5, 6, 7, 9]");

  // One registration covers every Parametric1Type (the parameter travels in the
  // metadata); each Parametric2Type name must be registered separately.
  ASSERT_OK(RegisterExtensionType(std::make_shared<Parametric1Type>(-1)));
  ASSERT_OK(RegisterExtensionType(p3_type));
  ASSERT_OK(RegisterExtensionType(p4_type));

  auto batch = RecordBatch::Make(schema({field("f0", p1_type), field("f1", p2_type),
                                         field("f2", p3_type), field("f3", p4_type)}),
                                 4, {p1, p2, p3, p4});

  std::shared_ptr<RecordBatch> read_batch;
  // A failed ASSERT inside RoundtripBatch only returns from the helper; stop
  // the test here instead of dereferencing a null read_batch.
  ASSERT_NO_FATAL_FAILURE(RoundtripBatch(batch, &read_batch));
  CompareBatch(*batch, *read_batch, false /* compare_metadata */);
}
308 
// Parametric1Type equality follows the parameter, and the type reports an
// empty fingerprint.
TEST_F(TestExtensionType, ParametricEquals) {
  auto six_a = std::make_shared<Parametric1Type>(6);
  auto six_b = std::make_shared<Parametric1Type>(6);
  auto three = std::make_shared<Parametric1Type>(3);

  ASSERT_TRUE(six_a->Equals(six_b));
  ASSERT_FALSE(six_a->Equals(three));

  ASSERT_EQ(six_a->fingerprint(), "");
}
319 
ExampleStruct()320 std::shared_ptr<Array> ExampleStruct() {
321   auto ext_type = std::make_shared<ExtStructType>();
322   auto storage_type = ext_type->storage_type();
323   auto arr = ArrayFromJSON(storage_type, "[[1, 0.1], [2, 0.2]]");
324 
325   auto ext_data = arr->data()->Copy();
326   ext_data->type = ext_type;
327   return MakeArray(ext_data);
328 }
329 
// ValidateFull() must accept extension arrays over fixed-size-binary (uuid),
// int32 (Parametric1Type) and struct (ExtStructType) storage.
TEST_F(TestExtensionType, ValidateExtensionArray) {
  auto uuid_arr = ExampleUuid();
  auto parametric_arr =
      ExampleParametric(std::make_shared<Parametric1Type>(6), "[null, 1, 2, 3]");
  auto struct_arr = ExampleStruct();

  for (const auto& arr : {uuid_arr, parametric_arr, struct_arr}) {
    ASSERT_OK(arr->ValidateFull());
  }
}
340 
341 }  // namespace arrow
342