1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/ipc/reader.h"
19
20 #include <algorithm>
21 #include <climits>
22 #include <cstdint>
23 #include <cstring>
24 #include <string>
25 #include <type_traits>
26 #include <utility>
27 #include <vector>
28
29 #include <flatbuffers/flatbuffers.h> // IWYU pragma: export
30
31 #include "arrow/array.h"
32 #include "arrow/buffer.h"
33 #include "arrow/extension_type.h"
34 #include "arrow/io/interfaces.h"
35 #include "arrow/io/memory.h"
36 #include "arrow/ipc/message.h"
37 #include "arrow/ipc/metadata_internal.h"
38 #include "arrow/ipc/writer.h"
39 #include "arrow/record_batch.h"
40 #include "arrow/sparse_tensor.h"
41 #include "arrow/status.h"
42 #include "arrow/type.h"
43 #include "arrow/type_traits.h"
44 #include "arrow/util/bit_util.h"
45 #include "arrow/util/checked_cast.h"
46 #include "arrow/util/compression.h"
47 #include "arrow/util/key_value_metadata.h"
48 #include "arrow/util/logging.h"
49 #include "arrow/util/parallel.h"
50 #include "arrow/util/ubsan.h"
51 #include "arrow/visitor_inline.h"
52
53 #include "generated/File_generated.h" // IWYU pragma: export
54 #include "generated/Message_generated.h"
55 #include "generated/Schema_generated.h"
56 #include "generated/SparseTensor_generated.h"
57
58 namespace arrow {
59
60 namespace flatbuf = org::apache::arrow::flatbuf;
61
62 using internal::checked_cast;
63 using internal::checked_pointer_cast;
64
65 namespace ipc {
66
67 using internal::FileBlock;
68 using internal::kArrowMagicBytes;
69
70 namespace {
71
// Build the standard IOError for receiving an IPC message of the wrong type.
Status InvalidMessageType(Message::Type expected, Message::Type actual) {
  return Status::IOError("Expected IPC message of type ", FormatMessageType(expected),
                         " but got ", FormatMessageType(actual));
}
76
// Return early with IOError unless `actual` is the expected IPC message type.
#define CHECK_MESSAGE_TYPE(expected, actual)            \
  do {                                                  \
    if ((actual) != (expected)) {                       \
      return InvalidMessageType((expected), (actual));  \
    }                                                   \
  } while (0)

// Return early with IOError if a message that requires a body arrived
// without one (e.g. a RECORD_BATCH or DICTIONARY_BATCH message).
#define CHECK_HAS_BODY(message)                                       \
  do {                                                                \
    if ((message).body() == nullptr) {                                \
      return Status::IOError("Expected body in IPC message of type ", \
                             FormatMessageType((message).type()));    \
    }                                                                 \
  } while (0)

// Return early with IOError if a body-less message (e.g. SCHEMA)
// unexpectedly carries a body.
#define CHECK_HAS_NO_BODY(message)                                      \
  do {                                                                  \
    if ((message).body_length() != 0) {                                 \
      return Status::IOError("Unexpected body in IPC message of type ", \
                             FormatMessageType((message).type()));      \
    }                                                                   \
  } while (0)
99
100 } // namespace
101
102 // ----------------------------------------------------------------------
103 // Record batch read path
104
/// \brief Reconstructs ArrayData from a record batch's flatbuffer metadata,
/// reading buffer contents from `file`.
///
/// The field_index and buffer_index are incremented based on how much of the
/// batch is "consumed" (through nested data reconstruction, for example)
class ArrayLoader {
 public:
  explicit ArrayLoader(const flatbuf::RecordBatch* metadata,
                       const DictionaryMemo* dictionary_memo,
                       const IpcReadOptions& options, io::RandomAccessFile* file)
      : metadata_(metadata),
        file_(file),
        dictionary_memo_(dictionary_memo),
        max_recursion_depth_(options.max_recursion_depth) {}

  // Read `length` bytes at `offset` from the message body. Performs no I/O
  // when skip_io_ is set (used by SkipField to advance indices only).
  Status ReadBuffer(int64_t offset, int64_t length, std::shared_ptr<Buffer>* out) {
    if (skip_io_) {
      return Status::OK();
    }
    // This construct permits overriding GetBuffer at compile time
    if (!BitUtil::IsMultipleOf8(offset)) {
      return Status::Invalid("Buffer ", buffer_index_,
                             " did not start on 8-byte aligned offset: ", offset);
    }
    return file_->ReadAt(offset, length).Value(out);
  }

  // Dispatch to the Visit() overload matching the runtime type
  Status LoadType(const DataType& type) { return VisitTypeInline(type, this); }

  // Load the array for `field` into `out`, consuming field/buffer metadata
  // entries as it goes. Fails once max_recursion_depth_ is exhausted.
  Status Load(const Field* field, ArrayData* out) {
    if (max_recursion_depth_ <= 0) {
      return Status::Invalid("Max recursion depth reached");
    }

    field_ = field;
    out_ = out;
    out_->type = field_->type();
    return LoadType(*field_->type());
  }

  // Advance field_index_/buffer_index_ past `field` without performing any
  // buffer I/O; used when a field is deselected by an inclusion mask.
  Status SkipField(const Field* field) {
    ArrayData dummy;
    skip_io_ = true;
    Status status = Load(field, &dummy);
    skip_io_ = false;
    return status;
  }

  // Resolve the flatbuffer buffer metadata at `buffer_index` and read its
  // bytes from the body.
  Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) {
    auto buffers = metadata_->buffers();
    CHECK_FLATBUFFERS_NOT_NULL(buffers, "RecordBatch.buffers");
    if (buffer_index >= static_cast<int>(buffers->size())) {
      return Status::IOError("buffer_index out of range.");
    }
    const flatbuf::Buffer* buffer = buffers->Get(buffer_index);
    if (buffer->length() == 0) {
      // Should never return a null buffer here.
      // (zero-sized buffer allocations are cheap)
      return AllocateBuffer(0).Value(out);
    } else {
      return ReadBuffer(buffer->offset(), buffer->length(), out);
    }
  }

  // Copy length/null_count for the FieldNode at `field_index` into `out`
  Status GetFieldMetadata(int field_index, ArrayData* out) {
    auto nodes = metadata_->nodes();
    CHECK_FLATBUFFERS_NOT_NULL(nodes, "Table.nodes");
    // pop off a field
    if (field_index >= static_cast<int>(nodes->size())) {
      return Status::Invalid("Ran out of field metadata, likely malformed");
    }
    const flatbuf::FieldNode* node = nodes->Get(field_index);

    out->length = node->length();
    out->null_count = node->null_count();
    out->offset = 0;
    return Status::OK();
  }

  // Consume one field node and the validity-bitmap buffer slot; shared by
  // every layout except NullType.
  Status LoadCommon() {
    // This only contains the length and null count, which we need to figure
    // out what to do with the buffers. For example, if null_count == 0, then
    // we can skip that buffer without reading from shared memory
    RETURN_NOT_OK(GetFieldMetadata(field_index_++, out_));

    // extract null_bitmap which is common to all arrays
    if (out_->null_count == 0) {
      out_->buffers[0] = nullptr;
    } else {
      RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[0]));
    }
    buffer_index_++;
    return Status::OK();
  }

  // Fixed-width layout: validity bitmap + a single data buffer
  template <typename TYPE>
  Status LoadPrimitive() {
    out_->buffers.resize(2);

    RETURN_NOT_OK(LoadCommon());
    if (out_->length > 0) {
      RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
    } else {
      // Zero-length array: still consume the buffer slot, but substitute an
      // empty placeholder buffer instead of reading
      buffer_index_++;
      out_->buffers[1].reset(new Buffer(nullptr, 0));
    }
    return Status::OK();
  }

  // Variable-size binary layout: validity bitmap + offsets + data
  template <typename TYPE>
  Status LoadBinary() {
    out_->buffers.resize(3);

    RETURN_NOT_OK(LoadCommon());
    RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
    return GetBuffer(buffer_index_++, &out_->buffers[2]);
  }

  // List-like layout: validity bitmap + offsets, then exactly one child
  template <typename TYPE>
  Status LoadList(const TYPE& type) {
    out_->buffers.resize(2);

    RETURN_NOT_OK(LoadCommon());
    RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));

    const int num_children = type.num_fields();
    if (num_children != 1) {
      return Status::Invalid("Wrong number of children: ", num_children);
    }

    return LoadChildren(type.fields());
  }

  // Recursively load each child field, charging one recursion level per
  // nesting step. Restores out_ afterwards since Load() clobbers it.
  Status LoadChildren(std::vector<std::shared_ptr<Field>> child_fields) {
    ArrayData* parent = out_;
    parent->child_data.reserve(static_cast<int>(child_fields.size()));
    for (const auto& child_field : child_fields) {
      auto field_array = std::make_shared<ArrayData>();
      --max_recursion_depth_;
      RETURN_NOT_OK(Load(child_field.get(), field_array.get()));
      ++max_recursion_depth_;
      parent->child_data.emplace_back(field_array);
    }
    out_ = parent;
    return Status::OK();
  }

  Status Visit(const NullType& type) {
    out_->buffers.resize(1);

    // ARROW-6379: NullType has no buffers in the IPC payload
    return GetFieldMetadata(field_index_++, out_);
  }

  // Any fixed-width type except fixed-size binary and dictionary
  template <typename T>
  enable_if_t<std::is_base_of<FixedWidthType, T>::value &&
                  !std::is_base_of<FixedSizeBinaryType, T>::value &&
                  !std::is_base_of<DictionaryType, T>::value,
              Status>
  Visit(const T& type) {
    return LoadPrimitive<T>();
  }

  template <typename T>
  enable_if_base_binary<T, Status> Visit(const T& type) {
    return LoadBinary<T>();
  }

  Status Visit(const FixedSizeBinaryType& type) {
    out_->buffers.resize(2);
    RETURN_NOT_OK(LoadCommon());
    return GetBuffer(buffer_index_++, &out_->buffers[1]);
  }

  template <typename T>
  enable_if_var_size_list<T, Status> Visit(const T& type) {
    return LoadList(type);
  }

  Status Visit(const MapType& type) {
    RETURN_NOT_OK(LoadList(type));
    // Ensure the entries child has the expected key/item structure
    return MapArray::ValidateChildData(out_->child_data);
  }

  Status Visit(const FixedSizeListType& type) {
    out_->buffers.resize(1);

    RETURN_NOT_OK(LoadCommon());

    const int num_children = type.num_fields();
    if (num_children != 1) {
      return Status::Invalid("Wrong number of children: ", num_children);
    }

    return LoadChildren(type.fields());
  }

  Status Visit(const StructType& type) {
    out_->buffers.resize(1);
    RETURN_NOT_OK(LoadCommon());
    return LoadChildren(type.fields());
  }

  Status Visit(const UnionType& type) {
    out_->buffers.resize(3);

    RETURN_NOT_OK(LoadCommon());
    if (out_->length > 0) {
      RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[1]));
      if (type.mode() == UnionMode::DENSE) {
        RETURN_NOT_OK(GetBuffer(buffer_index_ + 1, &out_->buffers[2]));
      }
    }
    // Consume the buffer slots even when the array is empty
    buffer_index_ += type.mode() == UnionMode::DENSE ? 2 : 1;
    return LoadChildren(type.fields());
  }

  Status Visit(const DictionaryType& type) {
    // Load the indices; the dictionary values were delivered separately and
    // must already be present in the memo
    RETURN_NOT_OK(LoadType(*type.index_type()));

    // Look up dictionary
    int64_t id = -1;
    RETURN_NOT_OK(dictionary_memo_->GetId(field_, &id));
    RETURN_NOT_OK(dictionary_memo_->GetDictionary(id, &out_->dictionary));

    return Status::OK();
  }

  // Extension arrays are loaded through their storage representation
  Status Visit(const ExtensionType& type) { return LoadType(*type.storage_type()); }

 private:
  const flatbuf::RecordBatch* metadata_;
  io::RandomAccessFile* file_;
  const DictionaryMemo* dictionary_memo_;
  int max_recursion_depth_;
  // Running positions into metadata_->nodes() / metadata_->buffers()
  int buffer_index_ = 0;
  int field_index_ = 0;
  bool skip_io_ = false;

  // Transient state for the field currently being loaded
  const Field* field_;
  ArrayData* out_;
};
344
DecompressBuffers(Compression::type compression,const IpcReadOptions & options,std::vector<std::shared_ptr<ArrayData>> * fields)345 Status DecompressBuffers(Compression::type compression, const IpcReadOptions& options,
346 std::vector<std::shared_ptr<ArrayData>>* fields) {
347 std::unique_ptr<util::Codec> codec;
348 ARROW_ASSIGN_OR_RAISE(codec, util::Codec::Create(compression));
349
350 auto DecompressOne = [&](int i) {
351 ArrayData* arr = (*fields)[i].get();
352 for (size_t i = 0; i < arr->buffers.size(); ++i) {
353 if (arr->buffers[i] == nullptr) {
354 continue;
355 }
356 if (arr->buffers[i]->size() == 0) {
357 continue;
358 }
359 if (arr->buffers[i]->size() < 8) {
360 return Status::Invalid(
361 "Likely corrupted message, compressed buffers "
362 "are larger than 8 bytes by construction");
363 }
364 const uint8_t* data = arr->buffers[i]->data();
365 int64_t compressed_size = arr->buffers[i]->size() - sizeof(int64_t);
366 int64_t uncompressed_size =
367 BitUtil::FromLittleEndian(util::SafeLoadAs<int64_t>(data));
368
369 ARROW_ASSIGN_OR_RAISE(auto uncompressed,
370 AllocateBuffer(uncompressed_size, options.memory_pool));
371
372 int64_t actual_decompressed;
373 ARROW_ASSIGN_OR_RAISE(
374 actual_decompressed,
375 codec->Decompress(compressed_size, data + sizeof(int64_t), uncompressed_size,
376 uncompressed->mutable_data()));
377 if (actual_decompressed != uncompressed_size) {
378 return Status::Invalid("Failed to fully decompress buffer, expected ",
379 uncompressed_size, " bytes but decompressed ",
380 actual_decompressed);
381 }
382 arr->buffers[i] = std::move(uncompressed);
383 }
384 return Status::OK();
385 };
386
387 return ::arrow::internal::OptionalParallelFor(
388 options.use_threads, static_cast<int>(fields->size()), DecompressOne);
389 }
390
LoadRecordBatchSubset(const flatbuf::RecordBatch * metadata,const std::shared_ptr<Schema> & schema,const std::vector<bool> & inclusion_mask,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,Compression::type compression,io::RandomAccessFile * file)391 Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
392 const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
393 const std::vector<bool>& inclusion_mask, const DictionaryMemo* dictionary_memo,
394 const IpcReadOptions& options, Compression::type compression,
395 io::RandomAccessFile* file) {
396 ArrayLoader loader(metadata, dictionary_memo, options, file);
397
398 std::vector<std::shared_ptr<ArrayData>> field_data;
399 std::vector<std::shared_ptr<Field>> schema_fields;
400
401 for (int i = 0; i < schema->num_fields(); ++i) {
402 if (inclusion_mask[i]) {
403 // Read field
404 auto arr = std::make_shared<ArrayData>();
405 RETURN_NOT_OK(loader.Load(schema->field(i).get(), arr.get()));
406 if (metadata->length() != arr->length) {
407 return Status::IOError("Array length did not match record batch length");
408 }
409 field_data.emplace_back(std::move(arr));
410 schema_fields.emplace_back(schema->field(i));
411 } else {
412 // Skip field. This logic must be executed to advance the state of the
413 // loader to the next field
414 RETURN_NOT_OK(loader.SkipField(schema->field(i).get()));
415 }
416 }
417
418 if (compression != Compression::UNCOMPRESSED) {
419 RETURN_NOT_OK(DecompressBuffers(compression, options, &field_data));
420 }
421
422 return RecordBatch::Make(::arrow::schema(std::move(schema_fields), schema->metadata()),
423 metadata->length(), std::move(field_data));
424 }
425
LoadRecordBatch(const flatbuf::RecordBatch * metadata,const std::shared_ptr<Schema> & schema,const std::vector<bool> & inclusion_mask,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,Compression::type compression,io::RandomAccessFile * file)426 Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
427 const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
428 const std::vector<bool>& inclusion_mask, const DictionaryMemo* dictionary_memo,
429 const IpcReadOptions& options, Compression::type compression,
430 io::RandomAccessFile* file) {
431 if (inclusion_mask.size() > 0) {
432 return LoadRecordBatchSubset(metadata, schema, inclusion_mask, dictionary_memo,
433 options, compression, file);
434 }
435
436 ArrayLoader loader(metadata, dictionary_memo, options, file);
437 std::vector<std::shared_ptr<ArrayData>> arrays(schema->num_fields());
438 for (int i = 0; i < schema->num_fields(); ++i) {
439 auto arr = std::make_shared<ArrayData>();
440 RETURN_NOT_OK(loader.Load(schema->field(i).get(), arr.get()));
441 if (metadata->length() != arr->length) {
442 return Status::IOError("Array length did not match record batch length");
443 }
444 arrays[i] = std::move(arr);
445 }
446 if (compression != Compression::UNCOMPRESSED) {
447 RETURN_NOT_OK(DecompressBuffers(compression, options, &arrays));
448 }
449 return RecordBatch::Make(schema, metadata->length(), std::move(arrays));
450 }
451
452 // ----------------------------------------------------------------------
453 // Array loading
454
GetCompression(const flatbuf::Message * message,Compression::type * out)455 Status GetCompression(const flatbuf::Message* message, Compression::type* out) {
456 *out = Compression::UNCOMPRESSED;
457 if (message->custom_metadata() != nullptr) {
458 // TODO: Ensure this deserialization only ever happens once
459 std::shared_ptr<KeyValueMetadata> metadata;
460 RETURN_NOT_OK(internal::GetKeyValueMetadata(message->custom_metadata(), &metadata));
461 int index = metadata->FindKey("ARROW:experimental_compression");
462 if (index != -1) {
463 ARROW_ASSIGN_OR_RAISE(*out,
464 util::Codec::GetCompressionType(metadata->value(index)));
465 }
466 RETURN_NOT_OK(internal::CheckCompressionSupported(*out));
467 }
468 return Status::OK();
469 }
470
ReadContiguousPayload(io::InputStream * file,std::unique_ptr<Message> * message)471 static Status ReadContiguousPayload(io::InputStream* file,
472 std::unique_ptr<Message>* message) {
473 ARROW_ASSIGN_OR_RAISE(*message, ReadMessage(file));
474 if (*message == nullptr) {
475 return Status::Invalid("Unable to read metadata at offset");
476 }
477 return Status::OK();
478 }
479
ReadRecordBatch(const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,io::InputStream * file)480 Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
481 const std::shared_ptr<Schema>& schema, const DictionaryMemo* dictionary_memo,
482 const IpcReadOptions& options, io::InputStream* file) {
483 std::unique_ptr<Message> message;
484 RETURN_NOT_OK(ReadContiguousPayload(file, &message));
485 CHECK_HAS_BODY(*message);
486 ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
487 return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, options,
488 reader.get());
489 }
490
ReadRecordBatch(const Message & message,const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options)491 Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
492 const Message& message, const std::shared_ptr<Schema>& schema,
493 const DictionaryMemo* dictionary_memo, const IpcReadOptions& options) {
494 CHECK_MESSAGE_TYPE(Message::RECORD_BATCH, message.type());
495 CHECK_HAS_BODY(message);
496 ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
497 return ReadRecordBatch(*message.metadata(), schema, dictionary_memo, options,
498 reader.get());
499 }
500
// Verify the flatbuffer metadata, extract the RecordBatch header and
// compression settings, then load the batch's buffers from `file`.
Result<std::shared_ptr<RecordBatch>> ReadRecordBatchInternal(
    const Buffer& metadata, const std::shared_ptr<Schema>& schema,
    const std::vector<bool>& inclusion_mask, const DictionaryMemo* dictionary_memo,
    const IpcReadOptions& options, io::RandomAccessFile* file) {
  const flatbuf::Message* message = nullptr;
  RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
  auto batch = message->header_as_RecordBatch();
  if (batch == nullptr) {
    return Status::IOError(
        "Header-type of flatbuffer-encoded Message is not RecordBatch.");
  }
  Compression::type compression;
  RETURN_NOT_OK(GetCompression(message, &compression));
  return LoadRecordBatch(batch, schema, inclusion_mask, dictionary_memo, options,
                         compression, file);
}
517
// If we are selecting only certain fields, populate an inclusion mask for fast lookups.
// Additionally, drop deselected fields from the reader's schema.
Status GetInclusionMaskAndOutSchema(const std::shared_ptr<Schema>& full_schema,
                                    const std::vector<int>& included_indices,
                                    std::vector<bool>* inclusion_mask,
                                    std::shared_ptr<Schema>* out_schema) {
  inclusion_mask->clear();
  if (included_indices.empty()) {
    // No selection: an empty mask means "include everything"
    *out_schema = full_schema;
    return Status::OK();
  }

  inclusion_mask->resize(full_schema->num_fields(), false);

  // Sort the selection so the output schema preserves the schema's field order
  auto included_indices_sorted = included_indices;
  std::sort(included_indices_sorted.begin(), included_indices_sorted.end());

  FieldVector included_fields;
  for (int i : included_indices_sorted) {
    // Out of bounds indices are rejected with an error (the previous comment
    // here claimed they were ignored, which contradicted the code)
    if (i < 0 || i >= full_schema->num_fields()) {
      return Status::Invalid("Out of bounds field index: ", i);
    }

    // Skip duplicate indices
    if (inclusion_mask->at(i)) continue;

    inclusion_mask->at(i) = true;
    included_fields.push_back(full_schema->field(i));
  }

  *out_schema = schema(std::move(included_fields), full_schema->metadata());
  return Status::OK();
}
551
// Deserialize a flatbuffer Schema into `schema`, recording its dictionary
// fields in `dictionary_memo`, then derive the caller-visible `out_schema`
// and the field inclusion mask from options.included_fields.
Status UnpackSchemaMessage(const void* opaque_schema, const IpcReadOptions& options,
                           DictionaryMemo* dictionary_memo,
                           std::shared_ptr<Schema>* schema,
                           std::shared_ptr<Schema>* out_schema,
                           std::vector<bool>* field_inclusion_mask) {
  RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, schema));

  // If we are selecting only certain fields, populate the inclusion mask now
  // for fast lookups
  return GetInclusionMaskAndOutSchema(*schema, options.included_fields,
                                      field_inclusion_mask, out_schema);
}
564
// Overload taking a full Schema IPC message: validates the message type and
// the absence of a body before unpacking the header.
Status UnpackSchemaMessage(const Message& message, const IpcReadOptions& options,
                           DictionaryMemo* dictionary_memo,
                           std::shared_ptr<Schema>* schema,
                           std::shared_ptr<Schema>* out_schema,
                           std::vector<bool>* field_inclusion_mask) {
  CHECK_MESSAGE_TYPE(Message::SCHEMA, message.type());
  CHECK_HAS_NO_BODY(message);

  return UnpackSchemaMessage(message.header(), options, dictionary_memo, schema,
                             out_schema, field_inclusion_mask);
}
576
ReadRecordBatch(const Buffer & metadata,const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,io::RandomAccessFile * file)577 Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
578 const Buffer& metadata, const std::shared_ptr<Schema>& schema,
579 const DictionaryMemo* dictionary_memo, const IpcReadOptions& options,
580 io::RandomAccessFile* file) {
581 std::shared_ptr<Schema> out_schema;
582 // Empty means do not use
583 std::vector<bool> inclusion_mask;
584 RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, options.included_fields,
585 &inclusion_mask, &out_schema));
586 return ReadRecordBatchInternal(metadata, schema, inclusion_mask, dictionary_memo,
587 options, file);
588 }
589
// Read a DictionaryBatch message: reconstruct the dictionary values as a
// single-column record batch and register them under the batch's id in
// `dictionary_memo`.
Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo,
                      const IpcReadOptions& options, io::RandomAccessFile* file) {
  const flatbuf::Message* message = nullptr;
  RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
  auto dictionary_batch = message->header_as_DictionaryBatch();
  if (dictionary_batch == nullptr) {
    return Status::IOError(
        "Header-type of flatbuffer-encoded Message is not DictionaryBatch.");
  }

  Compression::type compression;
  RETURN_NOT_OK(GetCompression(message, &compression));

  int64_t id = dictionary_batch->id();

  // Look up the field, which must have been added to the
  // DictionaryMemo already prior to invoking this function
  std::shared_ptr<DataType> value_type;
  RETURN_NOT_OK(dictionary_memo->GetDictionaryType(id, &value_type));

  // The field name is irrelevant; only the value type matters
  auto value_field = ::arrow::field("dummy", value_type);

  // The dictionary is embedded in a record batch with a single column
  auto batch_meta = dictionary_batch->data();
  CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data");

  std::shared_ptr<RecordBatch> batch;
  ARROW_ASSIGN_OR_RAISE(
      batch, LoadRecordBatch(batch_meta, ::arrow::schema({value_field}),
                             /*field_inclusion_mask=*/{}, dictionary_memo, options,
                             compression, file));
  if (batch->num_columns() != 1) {
    return Status::Invalid("Dictionary record batch must only contain one field");
  }
  auto dictionary = batch->column(0);
  return dictionary_memo->AddDictionary(id, dictionary);
}
627
// Read the dictionary contents out of an already-materialized
// DICTIONARY_BATCH message and register them with `dictionary_memo`.
Status ParseDictionary(const Message& message, DictionaryMemo* dictionary_memo,
                       const IpcReadOptions& options) {
  // Only invoke this method if we already know we have a dictionary message
  DCHECK_EQ(message.type(), Message::DICTIONARY_BATCH);
  CHECK_HAS_BODY(message);
  ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
  return ReadDictionary(*message.metadata(), dictionary_memo, options, reader.get());
}
636
// Apply a dictionary delta received mid-stream. Currently unsupported.
Status UpdateDictionaries(const Message& message, DictionaryMemo* dictionary_memo,
                          const IpcReadOptions& options) {
  // TODO(wesm): implement delta dictionaries
  return Status::NotImplemented("Delta dictionaries not yet implemented");
}
642
643 // ----------------------------------------------------------------------
644 // RecordBatchStreamReader implementation
645
// Stream reader: reads a schema message eagerly on Open(), then consumes
// dictionary batches lazily before the first record batch is returned.
class RecordBatchStreamReaderImpl : public RecordBatchStreamReader {
 public:
  // Read and unpack the stream's schema message. Dictionaries are NOT read
  // here; they are consumed on the first ReadNext() call.
  Status Open(std::unique_ptr<MessageReader> message_reader,
              const IpcReadOptions& options) {
    message_reader_ = std::move(message_reader);
    options_ = options;

    // Read schema
    ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
                          message_reader_->ReadNextMessage());
    if (!message) {
      return Status::Invalid("Tried reading schema message, was null or length 0");
    }

    return UnpackSchemaMessage(*message, options, &dictionary_memo_, &schema_,
                               &out_schema_, &field_inclusion_mask_);
  }

  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
    if (!have_read_initial_dictionaries_) {
      RETURN_NOT_OK(ReadInitialDictionaries());
    }

    if (empty_stream_) {
      // ARROW-6006: Degenerate case where stream contains no data, we do not
      // bother trying to read a RecordBatch message from the stream
      *batch = nullptr;
      return Status::OK();
    }

    ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
                          message_reader_->ReadNextMessage());
    if (message == nullptr) {
      // End of stream
      *batch = nullptr;
      return Status::OK();
    }

    if (message->type() == Message::DICTIONARY_BATCH) {
      // Mid-stream dictionary delta (currently NotImplemented)
      return UpdateDictionaries(*message, &dictionary_memo_, options_);
    } else {
      CHECK_HAS_BODY(*message);
      ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
      return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
                                     &dictionary_memo_, options_, reader.get())
          .Value(batch);
    }
  }

  // Schema as seen by the caller, with deselected fields dropped
  std::shared_ptr<Schema> schema() const override { return out_schema_; }

 private:
  Status ReadInitialDictionaries() {
    // We must receive all dictionaries before reconstructing the
    // first record batch. Subsequent dictionary deltas modify the memo
    std::unique_ptr<Message> message;

    // TODO(wesm): In future, we may want to reconcile the ids in the stream with
    // those found in the schema
    for (int i = 0; i < dictionary_memo_.num_fields(); ++i) {
      ARROW_ASSIGN_OR_RAISE(message, message_reader_->ReadNextMessage());
      if (!message) {
        if (i == 0) {
          /// ARROW-6006: If we fail to find any dictionaries in the stream, then
          /// it may be that the stream has a schema but no actual data. In such
          /// case we communicate that we were unable to find the dictionaries
          /// (but there was no failure otherwise), so the caller can decide what
          /// to do
          empty_stream_ = true;
          break;
        } else {
          // ARROW-6126, the stream terminated before receiving the expected
          // number of dictionaries
          return Status::Invalid("IPC stream ended without reading the expected number (",
                                 dictionary_memo_.num_fields(), ") of dictionaries");
        }
      }

      if (message->type() != Message::DICTIONARY_BATCH) {
        return Status::Invalid("IPC stream did not have the expected number (",
                               dictionary_memo_.num_fields(),
                               ") of dictionaries at the start of the stream");
      }
      RETURN_NOT_OK(ParseDictionary(*message, &dictionary_memo_, options_));
    }

    have_read_initial_dictionaries_ = true;
    return Status::OK();
  }

  std::unique_ptr<MessageReader> message_reader_;
  IpcReadOptions options_;
  std::vector<bool> field_inclusion_mask_;

  bool have_read_initial_dictionaries_ = false;

  // Flag to set in case where we fail to observe all dictionaries in a stream,
  // and so the reader should not attempt to parse any messages
  bool empty_stream_ = false;

  DictionaryMemo dictionary_memo_;
  // schema_ is the stream's full schema; out_schema_ has deselected fields
  // removed
  std::shared_ptr<Schema> schema_, out_schema_;
};
749
750 // ----------------------------------------------------------------------
751 // Stream reader constructors
752
// Factory: construct the implementation and eagerly read the schema message.
Result<std::shared_ptr<RecordBatchReader>> RecordBatchStreamReader::Open(
    std::unique_ptr<MessageReader> message_reader, const IpcReadOptions& options) {
  // Private ctor
  auto result = std::make_shared<RecordBatchStreamReaderImpl>();
  RETURN_NOT_OK(result->Open(std::move(message_reader), options));
  return result;
}
760
// Convenience overload for a borrowed stream; caller keeps `stream` alive.
Result<std::shared_ptr<RecordBatchReader>> RecordBatchStreamReader::Open(
    io::InputStream* stream, const IpcReadOptions& options) {
  return Open(MessageReader::Open(stream), options);
}
765
// Convenience overload for a shared stream; the reader keeps it alive.
Result<std::shared_ptr<RecordBatchReader>> RecordBatchStreamReader::Open(
    const std::shared_ptr<io::InputStream>& stream, const IpcReadOptions& options) {
  return Open(MessageReader::Open(stream), options);
}
770
771 // ----------------------------------------------------------------------
772 // Reader implementation
773
// Convert a footer flatbuffer Block entry into the internal FileBlock struct
static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) {
  return FileBlock{block->offset(), block->metaDataLength(), block->bodyLength()};
}
777
778 class RecordBatchFileReaderImpl : public RecordBatchFileReader {
779 public:
  // file_, footer_offset_ and footer_ are populated by Open()
  RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) {}
781
num_record_batches() const782 int num_record_batches() const override {
783 return static_cast<int>(internal::FlatBuffersVectorSize(footer_->recordBatches()));
784 }
785
version() const786 MetadataVersion version() const override {
787 return internal::GetMetadataVersion(footer_->version());
788 }
789
ReadRecordBatch(int i)790 Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(int i) override {
791 DCHECK_GE(i, 0);
792 DCHECK_LT(i, num_record_batches());
793
794 if (!read_dictionaries_) {
795 RETURN_NOT_OK(ReadDictionaries());
796 read_dictionaries_ = true;
797 }
798
799 std::unique_ptr<Message> message;
800 RETURN_NOT_OK(ReadMessageFromBlock(GetRecordBatchBlock(i), &message));
801
802 CHECK_HAS_BODY(*message);
803 ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
804 return ::arrow::ipc::ReadRecordBatch(*message->metadata(), schema_, &dictionary_memo_,
805 options_, reader.get());
806 }
807
Open(const std::shared_ptr<io::RandomAccessFile> & file,int64_t footer_offset,const IpcReadOptions & options)808 Status Open(const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
809 const IpcReadOptions& options) {
810 owned_file_ = file;
811 return Open(file.get(), footer_offset, options);
812 }
813
Open(io::RandomAccessFile * file,int64_t footer_offset,const IpcReadOptions & options)814 Status Open(io::RandomAccessFile* file, int64_t footer_offset,
815 const IpcReadOptions& options) {
816 file_ = file;
817 options_ = options;
818 footer_offset_ = footer_offset;
819 RETURN_NOT_OK(ReadFooter());
820
821 // Get the schema and record any observed dictionaries
822 return UnpackSchemaMessage(footer_->schema(), options, &dictionary_memo_, &schema_,
823 &out_schema_, &field_inclusion_mask_);
824 }
825
schema() const826 std::shared_ptr<Schema> schema() const override { return out_schema_; }
827
metadata() const828 std::shared_ptr<const KeyValueMetadata> metadata() const override { return metadata_; }
829
830 private:
GetRecordBatchBlock(int i) const831 FileBlock GetRecordBatchBlock(int i) const {
832 return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i));
833 }
834
GetDictionaryBlock(int i) const835 FileBlock GetDictionaryBlock(int i) const {
836 return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i));
837 }
838
ReadMessageFromBlock(const FileBlock & block,std::unique_ptr<Message> * out)839 Status ReadMessageFromBlock(const FileBlock& block, std::unique_ptr<Message>* out) {
840 if (!BitUtil::IsMultipleOf8(block.offset) ||
841 !BitUtil::IsMultipleOf8(block.metadata_length) ||
842 !BitUtil::IsMultipleOf8(block.body_length)) {
843 return Status::Invalid("Unaligned block in IPC file");
844 }
845
846 // TODO(wesm): this breaks integration tests, see ARROW-3256
847 // DCHECK_EQ((*out)->body_length(), block.body_length);
848
849 return ReadMessage(block.offset, block.metadata_length, file_).Value(out);
850 }
851
ReadDictionaries()852 Status ReadDictionaries() {
853 // Read all the dictionaries
854 for (int i = 0; i < num_dictionaries(); ++i) {
855 std::unique_ptr<Message> message;
856 RETURN_NOT_OK(ReadMessageFromBlock(GetDictionaryBlock(i), &message));
857
858 CHECK_HAS_BODY(*message);
859 ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
860 RETURN_NOT_OK(ReadDictionary(*message->metadata(), &dictionary_memo_, options_,
861 reader.get()));
862 }
863 return Status::OK();
864 }
865
ReadFooter()866 Status ReadFooter() {
867 const int32_t magic_size = static_cast<int>(strlen(kArrowMagicBytes));
868
869 if (footer_offset_ <= magic_size * 2 + 4) {
870 return Status::Invalid("File is too small: ", footer_offset_);
871 }
872
873 int file_end_size = static_cast<int>(magic_size + sizeof(int32_t));
874 ARROW_ASSIGN_OR_RAISE(auto buffer,
875 file_->ReadAt(footer_offset_ - file_end_size, file_end_size));
876
877 const int64_t expected_footer_size = magic_size + sizeof(int32_t);
878 if (buffer->size() < expected_footer_size) {
879 return Status::Invalid("Unable to read ", expected_footer_size, "from end of file");
880 }
881
882 if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) {
883 return Status::Invalid("Not an Arrow file");
884 }
885
886 int32_t footer_length = *reinterpret_cast<const int32_t*>(buffer->data());
887
888 if (footer_length <= 0 || footer_length > footer_offset_ - magic_size * 2 - 4) {
889 return Status::Invalid("File is smaller than indicated metadata size");
890 }
891
892 // Now read the footer
893 ARROW_ASSIGN_OR_RAISE(
894 footer_buffer_,
895 file_->ReadAt(footer_offset_ - footer_length - file_end_size, footer_length));
896
897 auto data = footer_buffer_->data();
898 flatbuffers::Verifier verifier(data, footer_buffer_->size(), 128);
899 if (!flatbuf::VerifyFooterBuffer(verifier)) {
900 return Status::IOError("Verification of flatbuffer-encoded Footer failed.");
901 }
902 footer_ = flatbuf::GetFooter(data);
903
904 auto fb_metadata = footer_->custom_metadata();
905 if (fb_metadata != nullptr) {
906 std::shared_ptr<KeyValueMetadata> md;
907 RETURN_NOT_OK(internal::GetKeyValueMetadata(fb_metadata, &md));
908 metadata_ = std::move(md); // const-ify
909 }
910
911 return Status::OK();
912 }
913
num_dictionaries() const914 int num_dictionaries() const {
915 return static_cast<int>(internal::FlatBuffersVectorSize(footer_->dictionaries()));
916 }
917
918 io::RandomAccessFile* file_;
919 IpcReadOptions options_;
920 std::vector<bool> field_inclusion_mask_;
921
922 std::shared_ptr<io::RandomAccessFile> owned_file_;
923
924 // The location where the Arrow file layout ends. May be the end of the file
925 // or some other location if embedded in a larger file.
926 int64_t footer_offset_;
927
928 // Footer metadata
929 std::shared_ptr<Buffer> footer_buffer_;
930 const flatbuf::Footer* footer_;
931 std::shared_ptr<const KeyValueMetadata> metadata_;
932
933 bool read_dictionaries_ = false;
934 DictionaryMemo dictionary_memo_;
935
936 // Reconstructed schema, including any read dictionaries
937 std::shared_ptr<Schema> schema_;
938 // Schema with deselected fields dropped
939 std::shared_ptr<Schema> out_schema_;
940 };
941
Open(io::RandomAccessFile * file,const IpcReadOptions & options)942 Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
943 io::RandomAccessFile* file, const IpcReadOptions& options) {
944 ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
945 return Open(file, footer_offset, options);
946 }
947
Open(io::RandomAccessFile * file,int64_t footer_offset,const IpcReadOptions & options)948 Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
949 io::RandomAccessFile* file, int64_t footer_offset, const IpcReadOptions& options) {
950 auto result = std::make_shared<RecordBatchFileReaderImpl>();
951 RETURN_NOT_OK(result->Open(file, footer_offset, options));
952 return result;
953 }
954
Open(const std::shared_ptr<io::RandomAccessFile> & file,const IpcReadOptions & options)955 Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
956 const std::shared_ptr<io::RandomAccessFile>& file, const IpcReadOptions& options) {
957 ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
958 return Open(file, footer_offset, options);
959 }
960
Open(const std::shared_ptr<io::RandomAccessFile> & file,int64_t footer_offset,const IpcReadOptions & options)961 Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
962 const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
963 const IpcReadOptions& options) {
964 auto result = std::make_shared<RecordBatchFileReaderImpl>();
965 RETURN_NOT_OK(result->Open(file, footer_offset, options));
966 return result;
967 }
968
OnEOS()969 Status Listener::OnEOS() { return Status::OK(); }
970
OnSchemaDecoded(std::shared_ptr<Schema> schema)971 Status Listener::OnSchemaDecoded(std::shared_ptr<Schema> schema) { return Status::OK(); }
972
OnRecordBatchDecoded(std::shared_ptr<RecordBatch> record_batch)973 Status Listener::OnRecordBatchDecoded(std::shared_ptr<RecordBatch> record_batch) {
974 return Status::NotImplemented("OnRecordBatchDecoded() callback isn't implemented");
975 }
976
// Push-based IPC stream decoder. Bytes are fed in via Consume(); decoded
// messages arrive through the MessageDecoderListener callbacks and are
// interpreted according to a small state machine:
//   SCHEMA -> INITIAL_DICTIONARIES -> RECORD_BATCHES -> EOS
class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener {
 private:
  enum State {
    SCHEMA,
    INITIAL_DICTIONARIES,
    RECORD_BATCHES,
    EOS,
  };

 public:
  explicit StreamDecoderImpl(std::shared_ptr<Listener> listener,
                             const IpcReadOptions& options)
      : MessageDecoderListener(),
        listener_(std::move(listener)),
        options_(options),
        state_(State::SCHEMA),
        // Hand the message decoder a non-owning shared_ptr to ourselves: the
        // no-op deleter prevents double destruction, since this object owns
        // the decoder (not the other way around).
        message_decoder_(std::shared_ptr<StreamDecoderImpl>(this, [](void*) {}),
                         options_.memory_pool),
        field_inclusion_mask_(),
        n_required_dictionaries_(0),
        dictionary_memo_(),
        schema_() {}

  // Dispatch a fully decoded message according to the current state.
  Status OnMessageDecoded(std::unique_ptr<Message> message) override {
    switch (state_) {
      case State::SCHEMA:
        ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message)));
        break;
      case State::INITIAL_DICTIONARIES:
        ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message)));
        break;
      case State::RECORD_BATCHES:
        ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message)));
        break;
      case State::EOS:
        // Messages after end-of-stream are ignored.
        break;
    }
    return Status::OK();
  }

  Status OnEOS() override {
    state_ = State::EOS;
    return listener_->OnEOS();
  }

  // Feed raw bytes (copied as needed by the message decoder).
  Status Consume(const uint8_t* data, int64_t size) {
    return message_decoder_.Consume(data, size);
  }

  // Feed a buffer (may be retained without copying).
  Status Consume(std::shared_ptr<Buffer> buffer) {
    return message_decoder_.Consume(std::move(buffer));
  }

  std::shared_ptr<Schema> schema() const { return out_schema_; }

  int64_t next_required_size() const { return message_decoder_.next_required_size(); }

 private:
  // First message must be the schema; determine how many dictionary batches
  // must follow before record batches may appear.
  Status OnSchemaMessageDecoded(std::unique_ptr<Message> message) {
    RETURN_NOT_OK(UnpackSchemaMessage(*message, options_, &dictionary_memo_, &schema_,
                                      &out_schema_, &field_inclusion_mask_));

    n_required_dictionaries_ = dictionary_memo_.num_fields();
    if (n_required_dictionaries_ == 0) {
      state_ = State::RECORD_BATCHES;
      RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
    } else {
      state_ = State::INITIAL_DICTIONARIES;
    }
    return Status::OK();
  }

  // Collect the leading dictionary batches; only notify the listener of the
  // schema once all required dictionaries have been seen.
  Status OnInitialDictionaryMessageDecoded(std::unique_ptr<Message> message) {
    if (message->type() != Message::DICTIONARY_BATCH) {
      return Status::Invalid("IPC stream did not have the expected number (",
                             dictionary_memo_.num_fields(),
                             ") of dictionaries at the start of the stream");
    }
    RETURN_NOT_OK(ParseDictionary(*message, &dictionary_memo_, options_));
    n_required_dictionaries_--;
    if (n_required_dictionaries_ == 0) {
      state_ = State::RECORD_BATCHES;
      ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
    }
    return Status::OK();
  }

  // In the record-batch state, dictionary batches are deltas/replacements;
  // anything else is decoded as a record batch and handed to the listener.
  Status OnRecordBatchMessageDecoded(std::unique_ptr<Message> message) {
    if (message->type() == Message::DICTIONARY_BATCH) {
      return UpdateDictionaries(*message, &dictionary_memo_, options_);
    } else {
      CHECK_HAS_BODY(*message);
      ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
      ARROW_ASSIGN_OR_RAISE(
          auto batch,
          ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
                                  &dictionary_memo_, options_, reader.get()));
      return listener_->OnRecordBatchDecoded(std::move(batch));
    }
  }

  std::shared_ptr<Listener> listener_;
  IpcReadOptions options_;
  State state_;
  MessageDecoder message_decoder_;
  std::vector<bool> field_inclusion_mask_;
  // Number of dictionary batches still expected before record batches start.
  int n_required_dictionaries_;
  DictionaryMemo dictionary_memo_;
  std::shared_ptr<Schema> schema_, out_schema_;
};
1087
StreamDecoder(std::shared_ptr<Listener> listener,const IpcReadOptions & options)1088 StreamDecoder::StreamDecoder(std::shared_ptr<Listener> listener,
1089 const IpcReadOptions& options) {
1090 impl_.reset(new StreamDecoderImpl(std::move(listener), options));
1091 }
1092
~StreamDecoder()1093 StreamDecoder::~StreamDecoder() {}
1094
Consume(const uint8_t * data,int64_t size)1095 Status StreamDecoder::Consume(const uint8_t* data, int64_t size) {
1096 return impl_->Consume(data, size);
1097 }
Consume(std::shared_ptr<Buffer> buffer)1098 Status StreamDecoder::Consume(std::shared_ptr<Buffer> buffer) {
1099 return impl_->Consume(std::move(buffer));
1100 }
1101
schema() const1102 std::shared_ptr<Schema> StreamDecoder::schema() const { return impl_->schema(); }
1103
next_required_size() const1104 int64_t StreamDecoder::next_required_size() const { return impl_->next_required_size(); }
1105
ReadSchema(io::InputStream * stream,DictionaryMemo * dictionary_memo)1106 Result<std::shared_ptr<Schema>> ReadSchema(io::InputStream* stream,
1107 DictionaryMemo* dictionary_memo) {
1108 std::unique_ptr<MessageReader> reader = MessageReader::Open(stream);
1109 ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message, reader->ReadNextMessage());
1110 if (!message) {
1111 return Status::Invalid("Tried reading schema message, was null or length 0");
1112 }
1113 CHECK_MESSAGE_TYPE(Message::SCHEMA, message->type());
1114 return ReadSchema(*message, dictionary_memo);
1115 }
1116
ReadSchema(const Message & message,DictionaryMemo * dictionary_memo)1117 Result<std::shared_ptr<Schema>> ReadSchema(const Message& message,
1118 DictionaryMemo* dictionary_memo) {
1119 std::shared_ptr<Schema> result;
1120 RETURN_NOT_OK(internal::GetSchema(message.header(), dictionary_memo, &result));
1121 return result;
1122 }
1123
ReadTensor(io::InputStream * file)1124 Result<std::shared_ptr<Tensor>> ReadTensor(io::InputStream* file) {
1125 std::unique_ptr<Message> message;
1126 RETURN_NOT_OK(ReadContiguousPayload(file, &message));
1127 return ReadTensor(*message);
1128 }
1129
ReadTensor(const Message & message)1130 Result<std::shared_ptr<Tensor>> ReadTensor(const Message& message) {
1131 std::shared_ptr<DataType> type;
1132 std::vector<int64_t> shape;
1133 std::vector<int64_t> strides;
1134 std::vector<std::string> dim_names;
1135 CHECK_HAS_BODY(message);
1136 RETURN_NOT_OK(internal::GetTensorMetadata(*message.metadata(), &type, &shape, &strides,
1137 &dim_names));
1138 return Tensor::Make(type, message.body(), shape, strides, dim_names);
1139 }
1140
1141 namespace {
1142
// Reconstruct a SparseCOOIndex from the flatbuffer metadata: the index is an
// (non_zero_length x ndim) integer tensor of coordinates read from `file`.
Result<std::shared_ptr<SparseIndex>> ReadSparseCOOIndex(
    const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
    int64_t non_zero_length, io::RandomAccessFile* file) {
  auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCOO();
  const auto ndim = static_cast<int64_t>(shape.size());

  std::shared_ptr<DataType> indices_type;
  RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(sparse_index, &indices_type));
  // Element size in bytes of the integer index type.
  const int64_t indices_elsize =
      checked_cast<const IntegerType&>(*indices_type).bit_width() / 8;

  auto* indices_buffer = sparse_index->indicesBuffer();
  ARROW_ASSIGN_OR_RAISE(auto indices_data,
                        file->ReadAt(indices_buffer->offset(), indices_buffer->length()));
  std::vector<int64_t> indices_shape({non_zero_length, ndim});
  auto* indices_strides = sparse_index->indicesStrides();
  std::vector<int64_t> strides(2);
  // Strides are optional in the flatbuffer; when present they must describe
  // exactly the 2-D coordinate matrix.
  if (indices_strides && indices_strides->size() > 0) {
    if (indices_strides->size() != 2) {
      return Status::Invalid("Wrong size for indicesStrides in SparseCOOIndex");
    }
    strides[0] = indices_strides->Get(0);
    strides[1] = indices_strides->Get(1);
  } else {
    // Row-major by default
    strides[0] = indices_elsize * ndim;
    strides[1] = indices_elsize;
  }
  return std::make_shared<SparseCOOIndex>(
      std::make_shared<Tensor>(indices_type, indices_data, indices_shape, strides));
}
1174
// Reconstruct a CSR or CSC index (selected by the compressedAxis field) for a
// sparse matrix: an indptr tensor of length (rows|cols)+1 plus an indices
// tensor of length non_zero_length, both read from `file`.
Result<std::shared_ptr<SparseIndex>> ReadSparseCSXIndex(
    const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
    int64_t non_zero_length, io::RandomAccessFile* file) {
  // CSR/CSC only applies to 2-D tensors (matrices).
  if (shape.size() != 2) {
    return Status::Invalid("Invalid shape length for a sparse matrix");
  }

  auto* sparse_index = sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX();

  std::shared_ptr<DataType> indptr_type, indices_type;
  RETURN_NOT_OK(
      internal::GetSparseCSXIndexMetadata(sparse_index, &indptr_type, &indices_type));

  auto* indptr_buffer = sparse_index->indptrBuffer();
  ARROW_ASSIGN_OR_RAISE(auto indptr_data,
                        file->ReadAt(indptr_buffer->offset(), indptr_buffer->length()));

  auto* indices_buffer = sparse_index->indicesBuffer();
  ARROW_ASSIGN_OR_RAISE(auto indices_data,
                        file->ReadAt(indices_buffer->offset(), indices_buffer->length()));

  std::vector<int64_t> indices_shape({non_zero_length});
  // Sanity-check that the indices buffer can hold non_zero_length elements.
  const auto indices_minimum_bytes =
      indices_shape[0] * checked_pointer_cast<FixedWidthType>(indices_type)->bit_width() /
      CHAR_BIT;
  if (indices_minimum_bytes > indices_buffer->length()) {
    return Status::Invalid("shape is inconsistent to the size of indices buffer");
  }

  switch (sparse_index->compressedAxis()) {
    case flatbuf::SparseMatrixCompressedAxis::Row: {
      // CSR: indptr has one entry per row plus one.
      std::vector<int64_t> indptr_shape({shape[0] + 1});
      const int64_t indptr_minimum_bytes =
          indptr_shape[0] *
          checked_pointer_cast<FixedWidthType>(indptr_type)->bit_width() / CHAR_BIT;
      if (indptr_minimum_bytes > indptr_buffer->length()) {
        return Status::Invalid("shape is inconsistent to the size of indptr buffer");
      }
      return std::make_shared<SparseCSRIndex>(
          std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
          std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
    }
    case flatbuf::SparseMatrixCompressedAxis::Column: {
      // CSC: indptr has one entry per column plus one.
      std::vector<int64_t> indptr_shape({shape[1] + 1});
      const int64_t indptr_minimum_bytes =
          indptr_shape[0] *
          checked_pointer_cast<FixedWidthType>(indptr_type)->bit_width() / CHAR_BIT;
      if (indptr_minimum_bytes > indptr_buffer->length()) {
        return Status::Invalid("shape is inconsistent to the size of indptr buffer");
      }
      return std::make_shared<SparseCSCIndex>(
          std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
          std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
    }
    default:
      return Status::Invalid("Invalid value of SparseMatrixCompressedAxis");
  }
}
1233
// Reconstruct a SparseCSFIndex: (ndim - 1) indptr buffers and ndim indices
// buffers read from `file`, together with the axis ordering from metadata.
Result<std::shared_ptr<SparseIndex>> ReadSparseCSFIndex(
    const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
    io::RandomAccessFile* file) {
  auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCSF();
  const auto ndim = static_cast<int64_t>(shape.size());
  auto* indptr_buffers = sparse_index->indptrBuffers();
  auto* indices_buffers = sparse_index->indicesBuffers();
  std::vector<std::shared_ptr<Buffer>> indptr_data(ndim - 1);
  std::vector<std::shared_ptr<Buffer>> indices_data(ndim);

  std::shared_ptr<DataType> indptr_type, indices_type;
  std::vector<int64_t> axis_order, indices_size;

  RETURN_NOT_OK(internal::GetSparseCSFIndexMetadata(
      sparse_index, &axis_order, &indices_size, &indptr_type, &indices_type));
  // Load every indptr buffer, then every indices buffer, as described by the
  // flatbuffer (offset, length) pairs.
  for (int i = 0; i < static_cast<int>(indptr_buffers->Length()); ++i) {
    ARROW_ASSIGN_OR_RAISE(indptr_data[i], file->ReadAt(indptr_buffers->Get(i)->offset(),
                                                       indptr_buffers->Get(i)->length()));
  }
  for (int i = 0; i < static_cast<int>(indices_buffers->Length()); ++i) {
    ARROW_ASSIGN_OR_RAISE(indices_data[i],
                          file->ReadAt(indices_buffers->Get(i)->offset(),
                                       indices_buffers->Get(i)->length()));
  }

  return SparseCSFIndex::Make(indptr_type, indices_type, indices_size, axis_order,
                              indptr_data, indices_data);
}
1262
MakeSparseTensorWithSparseCOOIndex(const std::shared_ptr<DataType> & type,const std::vector<int64_t> & shape,const std::vector<std::string> & dim_names,const std::shared_ptr<SparseCOOIndex> & sparse_index,int64_t non_zero_length,const std::shared_ptr<Buffer> & data)1263 Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCOOIndex(
1264 const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
1265 const std::vector<std::string>& dim_names,
1266 const std::shared_ptr<SparseCOOIndex>& sparse_index, int64_t non_zero_length,
1267 const std::shared_ptr<Buffer>& data) {
1268 return SparseCOOTensor::Make(sparse_index, type, data, shape, dim_names);
1269 }
1270
MakeSparseTensorWithSparseCSRIndex(const std::shared_ptr<DataType> & type,const std::vector<int64_t> & shape,const std::vector<std::string> & dim_names,const std::shared_ptr<SparseCSRIndex> & sparse_index,int64_t non_zero_length,const std::shared_ptr<Buffer> & data)1271 Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSRIndex(
1272 const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
1273 const std::vector<std::string>& dim_names,
1274 const std::shared_ptr<SparseCSRIndex>& sparse_index, int64_t non_zero_length,
1275 const std::shared_ptr<Buffer>& data) {
1276 return SparseCSRMatrix::Make(sparse_index, type, data, shape, dim_names);
1277 }
1278
MakeSparseTensorWithSparseCSCIndex(const std::shared_ptr<DataType> & type,const std::vector<int64_t> & shape,const std::vector<std::string> & dim_names,const std::shared_ptr<SparseCSCIndex> & sparse_index,int64_t non_zero_length,const std::shared_ptr<Buffer> & data)1279 Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSCIndex(
1280 const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
1281 const std::vector<std::string>& dim_names,
1282 const std::shared_ptr<SparseCSCIndex>& sparse_index, int64_t non_zero_length,
1283 const std::shared_ptr<Buffer>& data) {
1284 return SparseCSCMatrix::Make(sparse_index, type, data, shape, dim_names);
1285 }
1286
MakeSparseTensorWithSparseCSFIndex(const std::shared_ptr<DataType> & type,const std::vector<int64_t> & shape,const std::vector<std::string> & dim_names,const std::shared_ptr<SparseCSFIndex> & sparse_index,const std::shared_ptr<Buffer> & data)1287 Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSFIndex(
1288 const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
1289 const std::vector<std::string>& dim_names,
1290 const std::shared_ptr<SparseCSFIndex>& sparse_index,
1291 const std::shared_ptr<Buffer>& data) {
1292 return SparseCSFTensor::Make(sparse_index, type, data, shape, dim_names);
1293 }
1294
// Decode and validate the flatbuffer metadata of a sparse tensor message,
// returning the logical type/shape/etc. plus pointers into the verified
// flatbuffer (the SparseTensor table and its data Buffer descriptor).
Status ReadSparseTensorMetadata(const Buffer& metadata,
                                std::shared_ptr<DataType>* out_type,
                                std::vector<int64_t>* out_shape,
                                std::vector<std::string>* out_dim_names,
                                int64_t* out_non_zero_length,
                                SparseTensorFormat::type* out_format_id,
                                const flatbuf::SparseTensor** out_fb_sparse_tensor,
                                const flatbuf::Buffer** out_buffer) {
  RETURN_NOT_OK(internal::GetSparseTensorMetadata(
      metadata, out_type, out_shape, out_dim_names, out_non_zero_length, out_format_id));

  // Verify the raw flatbuffer before handing out pointers into it.
  const flatbuf::Message* message;
  RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));

  auto sparse_tensor = message->header_as_SparseTensor();
  if (sparse_tensor == nullptr) {
    return Status::IOError(
        "Header-type of flatbuffer-encoded Message is not SparseTensor.");
  }
  *out_fb_sparse_tensor = sparse_tensor;

  // The body data buffer must respect the IPC 8-byte alignment requirement.
  auto buffer = sparse_tensor->data();
  if (!BitUtil::IsMultipleOf8(buffer->offset())) {
    return Status::Invalid(
        "Buffer of sparse index data did not start on 8-byte aligned offset: ",
        buffer->offset());
  }
  *out_buffer = buffer;

  return Status::OK();
}
1326
1327 } // namespace
1328
1329 namespace internal {
1330
1331 namespace {
1332
GetSparseTensorBodyBufferCount(SparseTensorFormat::type format_id,const size_t ndim)1333 Result<size_t> GetSparseTensorBodyBufferCount(SparseTensorFormat::type format_id,
1334 const size_t ndim) {
1335 switch (format_id) {
1336 case SparseTensorFormat::COO:
1337 return 2;
1338
1339 case SparseTensorFormat::CSR:
1340 return 3;
1341
1342 case SparseTensorFormat::CSC:
1343 return 3;
1344
1345 case SparseTensorFormat::CSF:
1346 return 2 * ndim;
1347
1348 default:
1349 return Status::Invalid("Unrecognized sparse tensor format");
1350 }
1351 }
1352
CheckSparseTensorBodyBufferCount(const IpcPayload & payload,SparseTensorFormat::type sparse_tensor_format_id,const size_t ndim)1353 Status CheckSparseTensorBodyBufferCount(const IpcPayload& payload,
1354 SparseTensorFormat::type sparse_tensor_format_id,
1355 const size_t ndim) {
1356 size_t expected_body_buffer_count = 0;
1357 ARROW_ASSIGN_OR_RAISE(expected_body_buffer_count,
1358 GetSparseTensorBodyBufferCount(sparse_tensor_format_id, ndim));
1359 if (payload.body_buffers.size() != expected_body_buffer_count) {
1360 return Status::Invalid("Invalid body buffer count for a sparse tensor");
1361 }
1362
1363 return Status::OK();
1364 }
1365
1366 } // namespace
1367
ReadSparseTensorBodyBufferCount(const Buffer & metadata)1368 Result<size_t> ReadSparseTensorBodyBufferCount(const Buffer& metadata) {
1369 SparseTensorFormat::type format_id;
1370 std::vector<int64_t> shape;
1371
1372 RETURN_NOT_OK(internal::GetSparseTensorMetadata(metadata, nullptr, &shape, nullptr,
1373 nullptr, &format_id));
1374
1375 return GetSparseTensorBodyBufferCount(format_id, static_cast<size_t>(shape.size()));
1376 }
1377
// Reconstruct a SparseTensor from an in-memory IpcPayload: the sparse index
// buffers and the data buffer come directly from payload.body_buffers, laid
// out in the order defined by GetSparseTensorBodyBufferCount().
Result<std::shared_ptr<SparseTensor>> ReadSparseTensorPayload(const IpcPayload& payload) {
  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<std::string> dim_names;
  int64_t non_zero_length;
  SparseTensorFormat::type sparse_tensor_format_id;
  const flatbuf::SparseTensor* sparse_tensor;
  const flatbuf::Buffer* buffer;

  RETURN_NOT_OK(ReadSparseTensorMetadata(*payload.metadata, &type, &shape, &dim_names,
                                         &non_zero_length, &sparse_tensor_format_id,
                                         &sparse_tensor, &buffer));

  // Ensure body_buffers has exactly the slots the format expects before
  // indexing into it below.
  RETURN_NOT_OK(CheckSparseTensorBodyBufferCount(payload, sparse_tensor_format_id,
                                                 static_cast<size_t>(shape.size())));

  switch (sparse_tensor_format_id) {
    case SparseTensorFormat::COO: {
      // Layout: [0] = indices, [1] = data.
      std::shared_ptr<SparseCOOIndex> sparse_index;
      std::shared_ptr<DataType> indices_type;
      RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseTensorIndexCOO(), &indices_type));
      ARROW_ASSIGN_OR_RAISE(sparse_index,
                            SparseCOOIndex::Make(indices_type, shape, non_zero_length,
                                                 payload.body_buffers[0]));
      return MakeSparseTensorWithSparseCOOIndex(type, shape, dim_names, sparse_index,
                                                non_zero_length, payload.body_buffers[1]);
    }
    case SparseTensorFormat::CSR: {
      // Layout: [0] = indptr, [1] = indices, [2] = data.
      std::shared_ptr<SparseCSRIndex> sparse_index;
      std::shared_ptr<DataType> indptr_type;
      std::shared_ptr<DataType> indices_type;
      RETURN_NOT_OK(internal::GetSparseCSXIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX(), &indptr_type,
          &indices_type));
      ARROW_CHECK_EQ(indptr_type, indices_type);
      ARROW_ASSIGN_OR_RAISE(
          sparse_index,
          SparseCSRIndex::Make(indices_type, shape, non_zero_length,
                               payload.body_buffers[0], payload.body_buffers[1]));
      return MakeSparseTensorWithSparseCSRIndex(type, shape, dim_names, sparse_index,
                                                non_zero_length, payload.body_buffers[2]);
    }
    case SparseTensorFormat::CSC: {
      // Layout: [0] = indptr, [1] = indices, [2] = data.
      std::shared_ptr<SparseCSCIndex> sparse_index;
      std::shared_ptr<DataType> indptr_type;
      std::shared_ptr<DataType> indices_type;
      RETURN_NOT_OK(internal::GetSparseCSXIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX(), &indptr_type,
          &indices_type));
      ARROW_CHECK_EQ(indptr_type, indices_type);
      ARROW_ASSIGN_OR_RAISE(
          sparse_index,
          SparseCSCIndex::Make(indices_type, shape, non_zero_length,
                               payload.body_buffers[0], payload.body_buffers[1]));
      return MakeSparseTensorWithSparseCSCIndex(type, shape, dim_names, sparse_index,
                                                non_zero_length, payload.body_buffers[2]);
    }
    case SparseTensorFormat::CSF: {
      // Layout: [0 .. ndim-2] = indptr, [ndim-1 .. 2*ndim-2] = indices,
      // [2*ndim-1] = data.
      std::shared_ptr<SparseCSFIndex> sparse_index;
      std::shared_ptr<DataType> indptr_type, indices_type;
      std::vector<int64_t> axis_order, indices_size;

      RETURN_NOT_OK(internal::GetSparseCSFIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseTensorIndexCSF(), &axis_order,
          &indices_size, &indptr_type, &indices_type));
      ARROW_CHECK_EQ(indptr_type, indices_type);

      const int64_t ndim = shape.size();
      std::vector<std::shared_ptr<Buffer>> indptr_data(ndim - 1);
      std::vector<std::shared_ptr<Buffer>> indices_data(ndim);

      for (int64_t i = 0; i < ndim - 1; ++i) {
        indptr_data[i] = payload.body_buffers[i];
      }
      for (int64_t i = 0; i < ndim; ++i) {
        indices_data[i] = payload.body_buffers[i + ndim - 1];
      }

      ARROW_ASSIGN_OR_RAISE(sparse_index,
                            SparseCSFIndex::Make(indptr_type, indices_type, indices_size,
                                                 axis_order, indptr_data, indices_data));
      return MakeSparseTensorWithSparseCSFIndex(type, shape, dim_names, sparse_index,
                                                payload.body_buffers[2 * ndim - 1]);
    }
    default:
      return Status::Invalid("Unsupported sparse index format");
  }
}
1467
1468 } // namespace internal
1469
// Read a SparseTensor given its flatbuffer metadata and a random-access file
// containing the body: the sparse index buffers and the data buffer are read
// at the (offset, length) locations recorded in the metadata.
Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Buffer& metadata,
                                                       io::RandomAccessFile* file) {
  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<std::string> dim_names;
  int64_t non_zero_length;
  SparseTensorFormat::type sparse_tensor_format_id;
  const flatbuf::SparseTensor* sparse_tensor;
  const flatbuf::Buffer* buffer;

  RETURN_NOT_OK(ReadSparseTensorMetadata(metadata, &type, &shape, &dim_names,
                                         &non_zero_length, &sparse_tensor_format_id,
                                         &sparse_tensor, &buffer));

  // Read the tensor's value data; index buffers are read by the per-format
  // helpers below.
  ARROW_ASSIGN_OR_RAISE(auto data, file->ReadAt(buffer->offset(), buffer->length()));

  std::shared_ptr<SparseIndex> sparse_index;
  switch (sparse_tensor_format_id) {
    case SparseTensorFormat::COO: {
      ARROW_ASSIGN_OR_RAISE(
          sparse_index, ReadSparseCOOIndex(sparse_tensor, shape, non_zero_length, file));
      return MakeSparseTensorWithSparseCOOIndex(
          type, shape, dim_names, checked_pointer_cast<SparseCOOIndex>(sparse_index),
          non_zero_length, data);
    }
    case SparseTensorFormat::CSR: {
      ARROW_ASSIGN_OR_RAISE(
          sparse_index, ReadSparseCSXIndex(sparse_tensor, shape, non_zero_length, file));
      return MakeSparseTensorWithSparseCSRIndex(
          type, shape, dim_names, checked_pointer_cast<SparseCSRIndex>(sparse_index),
          non_zero_length, data);
    }
    case SparseTensorFormat::CSC: {
      ARROW_ASSIGN_OR_RAISE(
          sparse_index, ReadSparseCSXIndex(sparse_tensor, shape, non_zero_length, file));
      return MakeSparseTensorWithSparseCSCIndex(
          type, shape, dim_names, checked_pointer_cast<SparseCSCIndex>(sparse_index),
          non_zero_length, data);
    }
    case SparseTensorFormat::CSF: {
      ARROW_ASSIGN_OR_RAISE(sparse_index, ReadSparseCSFIndex(sparse_tensor, shape, file));
      return MakeSparseTensorWithSparseCSFIndex(
          type, shape, dim_names, checked_pointer_cast<SparseCSFIndex>(sparse_index),
          data);
    }
    default:
      return Status::Invalid("Unsupported sparse index format");
  }
}
1519
ReadSparseTensor(const Message & message)1520 Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Message& message) {
1521 CHECK_HAS_BODY(message);
1522 ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
1523 return ReadSparseTensor(*message.metadata(), reader.get());
1524 }
1525
ReadSparseTensor(io::InputStream * file)1526 Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(io::InputStream* file) {
1527 std::unique_ptr<Message> message;
1528 RETURN_NOT_OK(ReadContiguousPayload(file, &message));
1529 CHECK_MESSAGE_TYPE(Message::SPARSE_TENSOR, message->type());
1530 CHECK_HAS_BODY(*message);
1531 ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
1532 return ReadSparseTensor(*message->metadata(), reader.get());
1533 }
1534
1535 ///////////////////////////////////////////////////////////////////////////
1536 // Helpers for fuzzing
1537
1538 namespace internal {
1539
FuzzIpcStream(const uint8_t * data,int64_t size)1540 Status FuzzIpcStream(const uint8_t* data, int64_t size) {
1541 auto buffer = std::make_shared<Buffer>(data, size);
1542 io::BufferReader buffer_reader(buffer);
1543
1544 std::shared_ptr<RecordBatchReader> batch_reader;
1545 ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchStreamReader::Open(&buffer_reader));
1546
1547 while (true) {
1548 std::shared_ptr<arrow::RecordBatch> batch;
1549 RETURN_NOT_OK(batch_reader->ReadNext(&batch));
1550 if (batch == nullptr) {
1551 break;
1552 }
1553 RETURN_NOT_OK(batch->ValidateFull());
1554 }
1555
1556 return Status::OK();
1557 }
1558
FuzzIpcFile(const uint8_t * data,int64_t size)1559 Status FuzzIpcFile(const uint8_t* data, int64_t size) {
1560 auto buffer = std::make_shared<Buffer>(data, size);
1561 io::BufferReader buffer_reader(buffer);
1562
1563 std::shared_ptr<RecordBatchFileReader> batch_reader;
1564 ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchFileReader::Open(&buffer_reader));
1565
1566 const int n_batches = batch_reader->num_record_batches();
1567 for (int i = 0; i < n_batches; ++i) {
1568 ARROW_ASSIGN_OR_RAISE(auto batch, batch_reader->ReadRecordBatch(i));
1569 RETURN_NOT_OK(batch->ValidateFull());
1570 }
1571
1572 return Status::OK();
1573 }
1574
1575 } // namespace internal
1576
1577 // ----------------------------------------------------------------------
1578 // Deprecated functions
1579
// Deprecated out-parameter overload; delegates to the Result-returning
// Open() with default read options and stores the reader into *out.
Status RecordBatchStreamReader::Open(std::unique_ptr<MessageReader> message_reader,
                                     std::shared_ptr<RecordBatchReader>* out) {
  return Open(std::move(message_reader), IpcReadOptions::Defaults()).Value(out);
}
1584
// Deprecated out-parameter overload producing a unique_ptr. Constructs the
// concrete impl directly (rather than delegating to the Result-returning
// Open(), whose result is a shared_ptr) so ownership can stay unique.
Status RecordBatchStreamReader::Open(std::unique_ptr<MessageReader> message_reader,
                                     std::unique_ptr<RecordBatchReader>* out) {
  auto result =
      std::unique_ptr<RecordBatchStreamReaderImpl>(new RecordBatchStreamReaderImpl());
  RETURN_NOT_OK(result->Open(std::move(message_reader), IpcReadOptions::Defaults()));
  *out = std::move(result);
  return Status::OK();
}
1593
// Deprecated out-parameter overload; wraps the stream in a MessageReader
// and delegates to the Result-returning Open().
Status RecordBatchStreamReader::Open(io::InputStream* stream,
                                     std::shared_ptr<RecordBatchReader>* out) {
  return Open(MessageReader::Open(stream)).Value(out);
}
1598
// Deprecated out-parameter overload taking shared stream ownership; wraps
// the stream in a MessageReader and delegates to the Result-returning Open().
Status RecordBatchStreamReader::Open(const std::shared_ptr<io::InputStream>& stream,
                                     std::shared_ptr<RecordBatchReader>* out) {
  return Open(MessageReader::Open(stream)).Value(out);
}
1603
// Deprecated out-parameter overload; delegates to the Result-returning Open().
Status RecordBatchFileReader::Open(io::RandomAccessFile* file,
                                   std::shared_ptr<RecordBatchFileReader>* out) {
  return Open(file).Value(out);
}
1608
// Deprecated out-parameter overload with an explicit footer offset;
// delegates to the Result-returning Open().
Status RecordBatchFileReader::Open(io::RandomAccessFile* file, int64_t footer_offset,
                                   std::shared_ptr<RecordBatchFileReader>* out) {
  return Open(file, footer_offset).Value(out);
}
1613
// Deprecated out-parameter overload taking shared file ownership;
// delegates to the Result-returning Open().
Status RecordBatchFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file,
                                   std::shared_ptr<RecordBatchFileReader>* out) {
  return Open(file).Value(out);
}
1618
// Deprecated out-parameter overload taking shared file ownership and an
// explicit footer offset; delegates to the Result-returning Open().
Status RecordBatchFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file,
                                   int64_t footer_offset,
                                   std::shared_ptr<RecordBatchFileReader>* out) {
  return Open(file, footer_offset).Value(out);
}
1624
// Deprecated out-parameter overload; delegates to the Result-returning
// ReadSchema().
Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo,
                  std::shared_ptr<Schema>* out) {
  return ReadSchema(stream, dictionary_memo).Value(out);
}
1629
// Deprecated out-parameter overload reading from an already-decoded
// Message; delegates to the Result-returning ReadSchema().
Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo,
                  std::shared_ptr<Schema>* out) {
  return ReadSchema(message, dictionary_memo).Value(out);
}
1634
// Deprecated out-parameter overload; delegates to the Result-returning
// ReadRecordBatch() with default read options.
Status ReadRecordBatch(const std::shared_ptr<Schema>& schema,
                       const DictionaryMemo* dictionary_memo, io::InputStream* stream,
                       std::shared_ptr<RecordBatch>* out) {
  return ReadRecordBatch(schema, dictionary_memo, IpcReadOptions::Defaults(), stream)
      .Value(out);
}
1641
// Deprecated out-parameter overload reading from an already-decoded
// Message; delegates to the Result-returning ReadRecordBatch() with
// default read options.
Status ReadRecordBatch(const Message& message, const std::shared_ptr<Schema>& schema,
                       const DictionaryMemo* dictionary_memo,
                       std::shared_ptr<RecordBatch>* out) {
  return ReadRecordBatch(message, schema, dictionary_memo, IpcReadOptions::Defaults())
      .Value(out);
}
1648
// Deprecated out-parameter overload taking raw metadata plus a
// random-access file; delegates to the Result-returning ReadRecordBatch()
// with default read options.
Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr<Schema>& schema,
                       const DictionaryMemo* dictionary_memo, io::RandomAccessFile* file,
                       std::shared_ptr<RecordBatch>* out) {
  return ReadRecordBatch(metadata, schema, dictionary_memo, IpcReadOptions::Defaults(),
                         file)
      .Value(out);
}
1656
1657 } // namespace ipc
1658 } // namespace arrow
1659