1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "arrow/ipc/reader.h"
19 
20 #include <algorithm>
21 #include <climits>
22 #include <cstdint>
23 #include <cstring>
24 #include <string>
25 #include <type_traits>
26 #include <utility>
27 #include <vector>
28 
29 #include <flatbuffers/flatbuffers.h>  // IWYU pragma: export
30 
31 #include "arrow/array.h"
32 #include "arrow/buffer.h"
33 #include "arrow/extension_type.h"
34 #include "arrow/io/interfaces.h"
35 #include "arrow/io/memory.h"
36 #include "arrow/ipc/message.h"
37 #include "arrow/ipc/metadata_internal.h"
38 #include "arrow/ipc/writer.h"
39 #include "arrow/record_batch.h"
40 #include "arrow/sparse_tensor.h"
41 #include "arrow/status.h"
42 #include "arrow/type.h"
43 #include "arrow/type_traits.h"
44 #include "arrow/util/bit_util.h"
45 #include "arrow/util/checked_cast.h"
46 #include "arrow/util/compression.h"
47 #include "arrow/util/key_value_metadata.h"
48 #include "arrow/util/logging.h"
49 #include "arrow/util/parallel.h"
50 #include "arrow/util/ubsan.h"
51 #include "arrow/visitor_inline.h"
52 
53 #include "generated/File_generated.h"  // IWYU pragma: export
54 #include "generated/Message_generated.h"
55 #include "generated/Schema_generated.h"
56 #include "generated/SparseTensor_generated.h"
57 
58 namespace arrow {
59 
60 namespace flatbuf = org::apache::arrow::flatbuf;
61 
62 using internal::checked_cast;
63 using internal::checked_pointer_cast;
64 
65 namespace ipc {
66 
67 using internal::FileBlock;
68 using internal::kArrowMagicBytes;
69 
70 namespace {
71 
InvalidMessageType(Message::Type expected,Message::Type actual)72 Status InvalidMessageType(Message::Type expected, Message::Type actual) {
73   return Status::IOError("Expected IPC message of type ", FormatMessageType(expected),
74                          " but got ", FormatMessageType(actual));
75 }
76 
// Return an IOError from the enclosing function unless `actual` equals
// `expected`; used to validate the type of a decoded IPC message.
#define CHECK_MESSAGE_TYPE(expected, actual)           \
  do {                                                 \
    if ((actual) != (expected)) {                      \
      return InvalidMessageType((expected), (actual)); \
    }                                                  \
  } while (0)

// Return an IOError if the message carries no body buffer (record batch and
// dictionary batch messages must have one).
#define CHECK_HAS_BODY(message)                                       \
  do {                                                                \
    if ((message).body() == nullptr) {                                \
      return Status::IOError("Expected body in IPC message of type ", \
                             FormatMessageType((message).type()));    \
    }                                                                 \
  } while (0)

// Return an IOError if the message declares a non-empty body (schema
// messages are metadata-only).
#define CHECK_HAS_NO_BODY(message)                                      \
  do {                                                                  \
    if ((message).body_length() != 0) {                                 \
      return Status::IOError("Unexpected body in IPC message of type ", \
                             FormatMessageType((message).type()));      \
    }                                                                   \
  } while (0)
99 
100 }  // namespace
101 
102 // ----------------------------------------------------------------------
103 // Record batch read path
104 
/// \brief Rebuilds ArrayData trees for one record batch from its flatbuffer
/// metadata, reading buffer bytes from a RandomAccessFile.
///
/// The field_index and buffer_index are incremented based on how much of the
/// batch is "consumed" (through nested data reconstruction, for example)
class ArrayLoader {
 public:
  explicit ArrayLoader(const flatbuf::RecordBatch* metadata,
                       const DictionaryMemo* dictionary_memo,
                       const IpcReadOptions& options, io::RandomAccessFile* file)
      : metadata_(metadata),
        file_(file),
        dictionary_memo_(dictionary_memo),
        max_recursion_depth_(options.max_recursion_depth) {}

  // Read `length` bytes at `offset` from the file into *out. A no-op while
  // skip_io_ is set (see SkipField). Buffers must start on an 8-byte
  // aligned offset per the Arrow IPC format.
  Status ReadBuffer(int64_t offset, int64_t length, std::shared_ptr<Buffer>* out) {
    if (skip_io_) {
      return Status::OK();
    }
    // This construct permits overriding GetBuffer at compile time
    if (!BitUtil::IsMultipleOf8(offset)) {
      return Status::Invalid("Buffer ", buffer_index_,
                             " did not start on 8-byte aligned offset: ", offset);
    }
    return file_->ReadAt(offset, length).Value(out);
  }

  // Dispatch to the Visit() overload matching the concrete DataType.
  Status LoadType(const DataType& type) { return VisitTypeInline(type, this); }

  // Load one field's array data (including nested children) into *out,
  // consuming the corresponding field nodes and buffers from the metadata.
  Status Load(const Field* field, ArrayData* out) {
    if (max_recursion_depth_ <= 0) {
      return Status::Invalid("Max recursion depth reached");
    }

    field_ = field;
    out_ = out;
    out_->type = field_->type();
    return LoadType(*field_->type());
  }

  // Advance the field/buffer cursors past `field` without doing any file
  // I/O; used when a field is deselected by an inclusion mask.
  Status SkipField(const Field* field) {
    ArrayData dummy;
    skip_io_ = true;
    Status status = Load(field, &dummy);
    skip_io_ = false;
    return status;
  }

  // Resolve the flatbuffer Buffer descriptor at buffer_index and read its
  // bytes into *out. Zero-length buffers get an empty allocation instead of
  // a file read.
  Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) {
    auto buffers = metadata_->buffers();
    CHECK_FLATBUFFERS_NOT_NULL(buffers, "RecordBatch.buffers");
    if (buffer_index >= static_cast<int>(buffers->size())) {
      return Status::IOError("buffer_index out of range.");
    }
    const flatbuf::Buffer* buffer = buffers->Get(buffer_index);
    if (buffer->length() == 0) {
      // Should never return a null buffer here.
      // (zero-sized buffer allocations are cheap)
      return AllocateBuffer(0).Value(out);
    } else {
      return ReadBuffer(buffer->offset(), buffer->length(), out);
    }
  }

  // Copy length/null_count from the FieldNode at field_index into *out.
  Status GetFieldMetadata(int field_index, ArrayData* out) {
    auto nodes = metadata_->nodes();
    CHECK_FLATBUFFERS_NOT_NULL(nodes, "Table.nodes");
    // pop off a field
    if (field_index >= static_cast<int>(nodes->size())) {
      return Status::Invalid("Ran out of field metadata, likely malformed");
    }
    const flatbuf::FieldNode* node = nodes->Get(field_index);

    out->length = node->length();
    out->null_count = node->null_count();
    out->offset = 0;
    return Status::OK();
  }

  // Consume the next field node plus the validity bitmap slot (buffers[0])
  // that is common to every array type.
  Status LoadCommon() {
    // This only contains the length and null count, which we need to figure
    // out what to do with the buffers. For example, if null_count == 0, then
    // we can skip that buffer without reading from shared memory
    RETURN_NOT_OK(GetFieldMetadata(field_index_++, out_));

    // extract null_bitmap which is common to all arrays
    if (out_->null_count == 0) {
      out_->buffers[0] = nullptr;
    } else {
      RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[0]));
    }
    buffer_index_++;
    return Status::OK();
  }

  // Fixed-width types: validity bitmap + one data buffer.
  template <typename TYPE>
  Status LoadPrimitive() {
    out_->buffers.resize(2);

    RETURN_NOT_OK(LoadCommon());
    if (out_->length > 0) {
      RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
    } else {
      // Zero-length array: still consume the buffer slot, but substitute an
      // empty placeholder buffer
      buffer_index_++;
      out_->buffers[1].reset(new Buffer(nullptr, 0));
    }
    return Status::OK();
  }

  // Variable-size binary/string types: validity bitmap + offsets + data.
  template <typename TYPE>
  Status LoadBinary() {
    out_->buffers.resize(3);

    RETURN_NOT_OK(LoadCommon());
    RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
    return GetBuffer(buffer_index_++, &out_->buffers[2]);
  }

  // List-like types: validity bitmap + offsets, then exactly one child.
  template <typename TYPE>
  Status LoadList(const TYPE& type) {
    out_->buffers.resize(2);

    RETURN_NOT_OK(LoadCommon());
    RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));

    const int num_children = type.num_fields();
    if (num_children != 1) {
      return Status::Invalid("Wrong number of children: ", num_children);
    }

    return LoadChildren(type.fields());
  }

  // Recursively load each child field, decrementing the recursion budget
  // for the duration of each child load.
  Status LoadChildren(std::vector<std::shared_ptr<Field>> child_fields) {
    ArrayData* parent = out_;
    parent->child_data.reserve(static_cast<int>(child_fields.size()));
    for (const auto& child_field : child_fields) {
      auto field_array = std::make_shared<ArrayData>();
      --max_recursion_depth_;
      RETURN_NOT_OK(Load(child_field.get(), field_array.get()));
      ++max_recursion_depth_;
      parent->child_data.emplace_back(field_array);
    }
    // Load() repointed out_ at each child; restore it to the parent
    out_ = parent;
    return Status::OK();
  }

  Status Visit(const NullType& type) {
    out_->buffers.resize(1);

    // ARROW-6379: NullType has no buffers in the IPC payload
    return GetFieldMetadata(field_index_++, out_);
  }

  // Fixed-width types other than fixed-size binary and dictionary.
  template <typename T>
  enable_if_t<std::is_base_of<FixedWidthType, T>::value &&
                  !std::is_base_of<FixedSizeBinaryType, T>::value &&
                  !std::is_base_of<DictionaryType, T>::value,
              Status>
  Visit(const T& type) {
    return LoadPrimitive<T>();
  }

  template <typename T>
  enable_if_base_binary<T, Status> Visit(const T& type) {
    return LoadBinary<T>();
  }

  Status Visit(const FixedSizeBinaryType& type) {
    out_->buffers.resize(2);
    RETURN_NOT_OK(LoadCommon());
    return GetBuffer(buffer_index_++, &out_->buffers[1]);
  }

  template <typename T>
  enable_if_var_size_list<T, Status> Visit(const T& type) {
    return LoadList(type);
  }

  Status Visit(const MapType& type) {
    // A map is a list of key/value structs; validate that shape after load
    RETURN_NOT_OK(LoadList(type));
    return MapArray::ValidateChildData(out_->child_data);
  }

  Status Visit(const FixedSizeListType& type) {
    // No offsets buffer: only the validity bitmap, then a single child
    out_->buffers.resize(1);

    RETURN_NOT_OK(LoadCommon());

    const int num_children = type.num_fields();
    if (num_children != 1) {
      return Status::Invalid("Wrong number of children: ", num_children);
    }

    return LoadChildren(type.fields());
  }

  Status Visit(const StructType& type) {
    out_->buffers.resize(1);
    RETURN_NOT_OK(LoadCommon());
    return LoadChildren(type.fields());
  }

  Status Visit(const UnionType& type) {
    out_->buffers.resize(3);

    RETURN_NOT_OK(LoadCommon());
    if (out_->length > 0) {
      RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[1]));
      if (type.mode() == UnionMode::DENSE) {
        RETURN_NOT_OK(GetBuffer(buffer_index_ + 1, &out_->buffers[2]));
      }
    }
    // Consume the type-ids slot, plus the offsets slot for dense unions
    buffer_index_ += type.mode() == UnionMode::DENSE ? 2 : 1;
    return LoadChildren(type.fields());
  }

  Status Visit(const DictionaryType& type) {
    // The indices are stored like a primitive array of the index type
    RETURN_NOT_OK(LoadType(*type.index_type()));

    // Look up dictionary
    int64_t id = -1;
    RETURN_NOT_OK(dictionary_memo_->GetId(field_, &id));
    RETURN_NOT_OK(dictionary_memo_->GetDictionary(id, &out_->dictionary));

    return Status::OK();
  }

  // Extension arrays are serialized as their storage type
  Status Visit(const ExtensionType& type) { return LoadType(*type.storage_type()); }

 private:
  const flatbuf::RecordBatch* metadata_;
  io::RandomAccessFile* file_;
  const DictionaryMemo* dictionary_memo_;
  int max_recursion_depth_;
  // Cursors into metadata_->buffers() and metadata_->nodes(), advanced as
  // fields are consumed
  int buffer_index_ = 0;
  int field_index_ = 0;
  // When true, ReadBuffer performs no I/O (used by SkipField)
  bool skip_io_ = false;

  const Field* field_;
  ArrayData* out_;
};
344 
DecompressBuffers(Compression::type compression,const IpcReadOptions & options,std::vector<std::shared_ptr<ArrayData>> * fields)345 Status DecompressBuffers(Compression::type compression, const IpcReadOptions& options,
346                          std::vector<std::shared_ptr<ArrayData>>* fields) {
347   std::unique_ptr<util::Codec> codec;
348   ARROW_ASSIGN_OR_RAISE(codec, util::Codec::Create(compression));
349 
350   auto DecompressOne = [&](int i) {
351     ArrayData* arr = (*fields)[i].get();
352     for (size_t i = 0; i < arr->buffers.size(); ++i) {
353       if (arr->buffers[i] == nullptr) {
354         continue;
355       }
356       if (arr->buffers[i]->size() == 0) {
357         continue;
358       }
359       if (arr->buffers[i]->size() < 8) {
360         return Status::Invalid(
361             "Likely corrupted message, compressed buffers "
362             "are larger than 8 bytes by construction");
363       }
364       const uint8_t* data = arr->buffers[i]->data();
365       int64_t compressed_size = arr->buffers[i]->size() - sizeof(int64_t);
366       int64_t uncompressed_size =
367           BitUtil::FromLittleEndian(util::SafeLoadAs<int64_t>(data));
368 
369       ARROW_ASSIGN_OR_RAISE(auto uncompressed,
370                             AllocateBuffer(uncompressed_size, options.memory_pool));
371 
372       int64_t actual_decompressed;
373       ARROW_ASSIGN_OR_RAISE(
374           actual_decompressed,
375           codec->Decompress(compressed_size, data + sizeof(int64_t), uncompressed_size,
376                             uncompressed->mutable_data()));
377       if (actual_decompressed != uncompressed_size) {
378         return Status::Invalid("Failed to fully decompress buffer, expected ",
379                                uncompressed_size, " bytes but decompressed ",
380                                actual_decompressed);
381       }
382       arr->buffers[i] = std::move(uncompressed);
383     }
384     return Status::OK();
385   };
386 
387   return ::arrow::internal::OptionalParallelFor(
388       options.use_threads, static_cast<int>(fields->size()), DecompressOne);
389 }
390 
LoadRecordBatchSubset(const flatbuf::RecordBatch * metadata,const std::shared_ptr<Schema> & schema,const std::vector<bool> & inclusion_mask,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,Compression::type compression,io::RandomAccessFile * file)391 Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
392     const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
393     const std::vector<bool>& inclusion_mask, const DictionaryMemo* dictionary_memo,
394     const IpcReadOptions& options, Compression::type compression,
395     io::RandomAccessFile* file) {
396   ArrayLoader loader(metadata, dictionary_memo, options, file);
397 
398   std::vector<std::shared_ptr<ArrayData>> field_data;
399   std::vector<std::shared_ptr<Field>> schema_fields;
400 
401   for (int i = 0; i < schema->num_fields(); ++i) {
402     if (inclusion_mask[i]) {
403       // Read field
404       auto arr = std::make_shared<ArrayData>();
405       RETURN_NOT_OK(loader.Load(schema->field(i).get(), arr.get()));
406       if (metadata->length() != arr->length) {
407         return Status::IOError("Array length did not match record batch length");
408       }
409       field_data.emplace_back(std::move(arr));
410       schema_fields.emplace_back(schema->field(i));
411     } else {
412       // Skip field. This logic must be executed to advance the state of the
413       // loader to the next field
414       RETURN_NOT_OK(loader.SkipField(schema->field(i).get()));
415     }
416   }
417 
418   if (compression != Compression::UNCOMPRESSED) {
419     RETURN_NOT_OK(DecompressBuffers(compression, options, &field_data));
420   }
421 
422   return RecordBatch::Make(::arrow::schema(std::move(schema_fields), schema->metadata()),
423                            metadata->length(), std::move(field_data));
424 }
425 
LoadRecordBatch(const flatbuf::RecordBatch * metadata,const std::shared_ptr<Schema> & schema,const std::vector<bool> & inclusion_mask,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,Compression::type compression,io::RandomAccessFile * file)426 Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
427     const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
428     const std::vector<bool>& inclusion_mask, const DictionaryMemo* dictionary_memo,
429     const IpcReadOptions& options, Compression::type compression,
430     io::RandomAccessFile* file) {
431   if (inclusion_mask.size() > 0) {
432     return LoadRecordBatchSubset(metadata, schema, inclusion_mask, dictionary_memo,
433                                  options, compression, file);
434   }
435 
436   ArrayLoader loader(metadata, dictionary_memo, options, file);
437   std::vector<std::shared_ptr<ArrayData>> arrays(schema->num_fields());
438   for (int i = 0; i < schema->num_fields(); ++i) {
439     auto arr = std::make_shared<ArrayData>();
440     RETURN_NOT_OK(loader.Load(schema->field(i).get(), arr.get()));
441     if (metadata->length() != arr->length) {
442       return Status::IOError("Array length did not match record batch length");
443     }
444     arrays[i] = std::move(arr);
445   }
446   if (compression != Compression::UNCOMPRESSED) {
447     RETURN_NOT_OK(DecompressBuffers(compression, options, &arrays));
448   }
449   return RecordBatch::Make(schema, metadata->length(), std::move(arrays));
450 }
451 
452 // ----------------------------------------------------------------------
453 // Array loading
454 
GetCompression(const flatbuf::Message * message,Compression::type * out)455 Status GetCompression(const flatbuf::Message* message, Compression::type* out) {
456   *out = Compression::UNCOMPRESSED;
457   if (message->custom_metadata() != nullptr) {
458     // TODO: Ensure this deserialization only ever happens once
459     std::shared_ptr<KeyValueMetadata> metadata;
460     RETURN_NOT_OK(internal::GetKeyValueMetadata(message->custom_metadata(), &metadata));
461     int index = metadata->FindKey("ARROW:experimental_compression");
462     if (index != -1) {
463       ARROW_ASSIGN_OR_RAISE(*out,
464                             util::Codec::GetCompressionType(metadata->value(index)));
465     }
466     RETURN_NOT_OK(internal::CheckCompressionSupported(*out));
467   }
468   return Status::OK();
469 }
470 
ReadContiguousPayload(io::InputStream * file,std::unique_ptr<Message> * message)471 static Status ReadContiguousPayload(io::InputStream* file,
472                                     std::unique_ptr<Message>* message) {
473   ARROW_ASSIGN_OR_RAISE(*message, ReadMessage(file));
474   if (*message == nullptr) {
475     return Status::Invalid("Unable to read metadata at offset");
476   }
477   return Status::OK();
478 }
479 
ReadRecordBatch(const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,io::InputStream * file)480 Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
481     const std::shared_ptr<Schema>& schema, const DictionaryMemo* dictionary_memo,
482     const IpcReadOptions& options, io::InputStream* file) {
483   std::unique_ptr<Message> message;
484   RETURN_NOT_OK(ReadContiguousPayload(file, &message));
485   CHECK_HAS_BODY(*message);
486   ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
487   return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, options,
488                          reader.get());
489 }
490 
ReadRecordBatch(const Message & message,const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options)491 Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
492     const Message& message, const std::shared_ptr<Schema>& schema,
493     const DictionaryMemo* dictionary_memo, const IpcReadOptions& options) {
494   CHECK_MESSAGE_TYPE(Message::RECORD_BATCH, message.type());
495   CHECK_HAS_BODY(message);
496   ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
497   return ReadRecordBatch(*message.metadata(), schema, dictionary_memo, options,
498                          reader.get());
499 }
500 
ReadRecordBatchInternal(const Buffer & metadata,const std::shared_ptr<Schema> & schema,const std::vector<bool> & inclusion_mask,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,io::RandomAccessFile * file)501 Result<std::shared_ptr<RecordBatch>> ReadRecordBatchInternal(
502     const Buffer& metadata, const std::shared_ptr<Schema>& schema,
503     const std::vector<bool>& inclusion_mask, const DictionaryMemo* dictionary_memo,
504     const IpcReadOptions& options, io::RandomAccessFile* file) {
505   const flatbuf::Message* message = nullptr;
506   RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
507   auto batch = message->header_as_RecordBatch();
508   if (batch == nullptr) {
509     return Status::IOError(
510         "Header-type of flatbuffer-encoded Message is not RecordBatch.");
511   }
512   Compression::type compression;
513   RETURN_NOT_OK(GetCompression(message, &compression));
514   return LoadRecordBatch(batch, schema, inclusion_mask, dictionary_memo, options,
515                          compression, file);
516 }
517 
518 // If we are selecting only certain fields, populate an inclusion mask for fast lookups.
519 // Additionally, drop deselected fields from the reader's schema.
GetInclusionMaskAndOutSchema(const std::shared_ptr<Schema> & full_schema,const std::vector<int> & included_indices,std::vector<bool> * inclusion_mask,std::shared_ptr<Schema> * out_schema)520 Status GetInclusionMaskAndOutSchema(const std::shared_ptr<Schema>& full_schema,
521                                     const std::vector<int>& included_indices,
522                                     std::vector<bool>* inclusion_mask,
523                                     std::shared_ptr<Schema>* out_schema) {
524   inclusion_mask->clear();
525   if (included_indices.empty()) {
526     *out_schema = full_schema;
527     return Status::OK();
528   }
529 
530   inclusion_mask->resize(full_schema->num_fields(), false);
531 
532   auto included_indices_sorted = included_indices;
533   std::sort(included_indices_sorted.begin(), included_indices_sorted.end());
534 
535   FieldVector included_fields;
536   for (int i : included_indices_sorted) {
537     // Ignore out of bounds indices
538     if (i < 0 || i >= full_schema->num_fields()) {
539       return Status::Invalid("Out of bounds field index: ", i);
540     }
541 
542     if (inclusion_mask->at(i)) continue;
543 
544     inclusion_mask->at(i) = true;
545     included_fields.push_back(full_schema->field(i));
546   }
547 
548   *out_schema = schema(std::move(included_fields), full_schema->metadata());
549   return Status::OK();
550 }
551 
UnpackSchemaMessage(const void * opaque_schema,const IpcReadOptions & options,DictionaryMemo * dictionary_memo,std::shared_ptr<Schema> * schema,std::shared_ptr<Schema> * out_schema,std::vector<bool> * field_inclusion_mask)552 Status UnpackSchemaMessage(const void* opaque_schema, const IpcReadOptions& options,
553                            DictionaryMemo* dictionary_memo,
554                            std::shared_ptr<Schema>* schema,
555                            std::shared_ptr<Schema>* out_schema,
556                            std::vector<bool>* field_inclusion_mask) {
557   RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, schema));
558 
559   // If we are selecting only certain fields, populate the inclusion mask now
560   // for fast lookups
561   return GetInclusionMaskAndOutSchema(*schema, options.included_fields,
562                                       field_inclusion_mask, out_schema);
563 }
564 
UnpackSchemaMessage(const Message & message,const IpcReadOptions & options,DictionaryMemo * dictionary_memo,std::shared_ptr<Schema> * schema,std::shared_ptr<Schema> * out_schema,std::vector<bool> * field_inclusion_mask)565 Status UnpackSchemaMessage(const Message& message, const IpcReadOptions& options,
566                            DictionaryMemo* dictionary_memo,
567                            std::shared_ptr<Schema>* schema,
568                            std::shared_ptr<Schema>* out_schema,
569                            std::vector<bool>* field_inclusion_mask) {
570   CHECK_MESSAGE_TYPE(Message::SCHEMA, message.type());
571   CHECK_HAS_NO_BODY(message);
572 
573   return UnpackSchemaMessage(message.header(), options, dictionary_memo, schema,
574                              out_schema, field_inclusion_mask);
575 }
576 
ReadRecordBatch(const Buffer & metadata,const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,const IpcReadOptions & options,io::RandomAccessFile * file)577 Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
578     const Buffer& metadata, const std::shared_ptr<Schema>& schema,
579     const DictionaryMemo* dictionary_memo, const IpcReadOptions& options,
580     io::RandomAccessFile* file) {
581   std::shared_ptr<Schema> out_schema;
582   // Empty means do not use
583   std::vector<bool> inclusion_mask;
584   RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, options.included_fields,
585                                              &inclusion_mask, &out_schema));
586   return ReadRecordBatchInternal(metadata, schema, inclusion_mask, dictionary_memo,
587                                  options, file);
588 }
589 
// Read a DictionaryBatch message: decode the single-column record batch it
// embeds and register the resulting dictionary with the memo under the id
// declared in the message.
Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo,
                      const IpcReadOptions& options, io::RandomAccessFile* file) {
  const flatbuf::Message* message = nullptr;
  RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
  auto dictionary_batch = message->header_as_DictionaryBatch();
  if (dictionary_batch == nullptr) {
    return Status::IOError(
        "Header-type of flatbuffer-encoded Message is not DictionaryBatch.");
  }

  Compression::type compression;
  RETURN_NOT_OK(GetCompression(message, &compression));

  int64_t id = dictionary_batch->id();

  // Look up the field, which must have been added to the
  // DictionaryMemo already prior to invoking this function
  std::shared_ptr<DataType> value_type;
  RETURN_NOT_OK(dictionary_memo->GetDictionaryType(id, &value_type));

  // The field name is irrelevant here; only the value type matters
  auto value_field = ::arrow::field("dummy", value_type);

  // The dictionary is embedded in a record batch with a single column
  auto batch_meta = dictionary_batch->data();
  CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data");

  std::shared_ptr<RecordBatch> batch;
  ARROW_ASSIGN_OR_RAISE(
      batch, LoadRecordBatch(batch_meta, ::arrow::schema({value_field}),
                             /*field_inclusion_mask=*/{}, dictionary_memo, options,
                             compression, file));
  if (batch->num_columns() != 1) {
    return Status::Invalid("Dictionary record batch must only contain one field");
  }
  auto dictionary = batch->column(0);
  return dictionary_memo->AddDictionary(id, dictionary);
}
627 
ParseDictionary(const Message & message,DictionaryMemo * dictionary_memo,const IpcReadOptions & options)628 Status ParseDictionary(const Message& message, DictionaryMemo* dictionary_memo,
629                        const IpcReadOptions& options) {
630   // Only invoke this method if we already know we have a dictionary message
631   DCHECK_EQ(message.type(), Message::DICTIONARY_BATCH);
632   CHECK_HAS_BODY(message);
633   ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
634   return ReadDictionary(*message.metadata(), dictionary_memo, options, reader.get());
635 }
636 
UpdateDictionaries(const Message & message,DictionaryMemo * dictionary_memo,const IpcReadOptions & options)637 Status UpdateDictionaries(const Message& message, DictionaryMemo* dictionary_memo,
638                           const IpcReadOptions& options) {
639   // TODO(wesm): implement delta dictionaries
640   return Status::NotImplemented("Delta dictionaries not yet implemented");
641 }
642 
643 // ----------------------------------------------------------------------
644 // RecordBatchStreamReader implementation
645 
646 class RecordBatchStreamReaderImpl : public RecordBatchStreamReader {
647  public:
Open(std::unique_ptr<MessageReader> message_reader,const IpcReadOptions & options)648   Status Open(std::unique_ptr<MessageReader> message_reader,
649               const IpcReadOptions& options) {
650     message_reader_ = std::move(message_reader);
651     options_ = options;
652 
653     // Read schema
654     ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
655                           message_reader_->ReadNextMessage());
656     if (!message) {
657       return Status::Invalid("Tried reading schema message, was null or length 0");
658     }
659 
660     return UnpackSchemaMessage(*message, options, &dictionary_memo_, &schema_,
661                                &out_schema_, &field_inclusion_mask_);
662   }
663 
ReadNext(std::shared_ptr<RecordBatch> * batch)664   Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
665     if (!have_read_initial_dictionaries_) {
666       RETURN_NOT_OK(ReadInitialDictionaries());
667     }
668 
669     if (empty_stream_) {
670       // ARROW-6006: Degenerate case where stream contains no data, we do not
671       // bother trying to read a RecordBatch message from the stream
672       *batch = nullptr;
673       return Status::OK();
674     }
675 
676     ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
677                           message_reader_->ReadNextMessage());
678     if (message == nullptr) {
679       // End of stream
680       *batch = nullptr;
681       return Status::OK();
682     }
683 
684     if (message->type() == Message::DICTIONARY_BATCH) {
685       return UpdateDictionaries(*message, &dictionary_memo_, options_);
686     } else {
687       CHECK_HAS_BODY(*message);
688       ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
689       return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
690                                      &dictionary_memo_, options_, reader.get())
691           .Value(batch);
692     }
693   }
694 
schema() const695   std::shared_ptr<Schema> schema() const override { return out_schema_; }
696 
697  private:
  // Consume the dictionary batches that must precede the first record batch,
  // populating dictionary_memo_. If the stream ends immediately after the
  // schema, sets empty_stream_ instead of failing (ARROW-6006); if it ends
  // after only some of the expected dictionaries, returns Invalid (ARROW-6126).
  Status ReadInitialDictionaries() {
    // We must receive all dictionaries before reconstructing the
    // first record batch. Subsequent dictionary deltas modify the memo
    std::unique_ptr<Message> message;

    // TODO(wesm): In future, we may want to reconcile the ids in the stream with
    // those found in the schema
    for (int i = 0; i < dictionary_memo_.num_fields(); ++i) {
      ARROW_ASSIGN_OR_RAISE(message, message_reader_->ReadNextMessage());
      if (!message) {
        if (i == 0) {
          /// ARROW-6006: If we fail to find any dictionaries in the stream, then
          /// it may be that the stream has a schema but no actual data. In such
          /// case we communicate that we were unable to find the dictionaries
          /// (but there was no failure otherwise), so the caller can decide what
          /// to do
          empty_stream_ = true;
          break;
        } else {
          // ARROW-6126, the stream terminated before receiving the expected
          // number of dictionaries
          return Status::Invalid("IPC stream ended without reading the expected number (",
                                 dictionary_memo_.num_fields(), ") of dictionaries");
        }
      }

      // Anything other than a dictionary batch here violates the stream protocol
      if (message->type() != Message::DICTIONARY_BATCH) {
        return Status::Invalid("IPC stream did not have the expected number (",
                               dictionary_memo_.num_fields(),
                               ") of dictionaries at the start of the stream");
      }
      // Decode the dictionary batch into dictionary_memo_
      RETURN_NOT_OK(ParseDictionary(*message, &dictionary_memo_, options_));
    }

    have_read_initial_dictionaries_ = true;
    return Status::OK();
  }
735 
  // Source of framed IPC messages
  std::unique_ptr<MessageReader> message_reader_;
  IpcReadOptions options_;
  // Parallel to schema_'s fields; true if the field is kept in out_schema_
  // (produced by UnpackSchemaMessage)
  std::vector<bool> field_inclusion_mask_;

  // Whether ReadInitialDictionaries() has run (it runs lazily on first ReadNext)
  bool have_read_initial_dictionaries_ = false;

  // Flag to set in case where we fail to observe all dictionaries in a stream,
  // and so the reader should not attempt to parse any messages
  bool empty_stream_ = false;

  DictionaryMemo dictionary_memo_;
  // schema_: as read from the stream; out_schema_: deselected fields dropped
  std::shared_ptr<Schema> schema_, out_schema_;
};
749 
750 // ----------------------------------------------------------------------
751 // Stream reader constructors
752 
Open(std::unique_ptr<MessageReader> message_reader,const IpcReadOptions & options)753 Result<std::shared_ptr<RecordBatchReader>> RecordBatchStreamReader::Open(
754     std::unique_ptr<MessageReader> message_reader, const IpcReadOptions& options) {
755   // Private ctor
756   auto result = std::make_shared<RecordBatchStreamReaderImpl>();
757   RETURN_NOT_OK(result->Open(std::move(message_reader), options));
758   return result;
759 }
760 
Open(io::InputStream * stream,const IpcReadOptions & options)761 Result<std::shared_ptr<RecordBatchReader>> RecordBatchStreamReader::Open(
762     io::InputStream* stream, const IpcReadOptions& options) {
763   return Open(MessageReader::Open(stream), options);
764 }
765 
Open(const std::shared_ptr<io::InputStream> & stream,const IpcReadOptions & options)766 Result<std::shared_ptr<RecordBatchReader>> RecordBatchStreamReader::Open(
767     const std::shared_ptr<io::InputStream>& stream, const IpcReadOptions& options) {
768   return Open(MessageReader::Open(stream), options);
769 }
770 
771 // ----------------------------------------------------------------------
772 // Reader implementation
773 
FileBlockFromFlatbuffer(const flatbuf::Block * block)774 static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) {
775   return FileBlock{block->offset(), block->metaDataLength(), block->bodyLength()};
776 }
777 
778 class RecordBatchFileReaderImpl : public RecordBatchFileReader {
779  public:
RecordBatchFileReaderImpl()780   RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) {}
781 
num_record_batches() const782   int num_record_batches() const override {
783     return static_cast<int>(internal::FlatBuffersVectorSize(footer_->recordBatches()));
784   }
785 
version() const786   MetadataVersion version() const override {
787     return internal::GetMetadataVersion(footer_->version());
788   }
789 
ReadRecordBatch(int i)790   Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(int i) override {
791     DCHECK_GE(i, 0);
792     DCHECK_LT(i, num_record_batches());
793 
794     if (!read_dictionaries_) {
795       RETURN_NOT_OK(ReadDictionaries());
796       read_dictionaries_ = true;
797     }
798 
799     std::unique_ptr<Message> message;
800     RETURN_NOT_OK(ReadMessageFromBlock(GetRecordBatchBlock(i), &message));
801 
802     CHECK_HAS_BODY(*message);
803     ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
804     return ::arrow::ipc::ReadRecordBatch(*message->metadata(), schema_, &dictionary_memo_,
805                                          options_, reader.get());
806   }
807 
Open(const std::shared_ptr<io::RandomAccessFile> & file,int64_t footer_offset,const IpcReadOptions & options)808   Status Open(const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
809               const IpcReadOptions& options) {
810     owned_file_ = file;
811     return Open(file.get(), footer_offset, options);
812   }
813 
Open(io::RandomAccessFile * file,int64_t footer_offset,const IpcReadOptions & options)814   Status Open(io::RandomAccessFile* file, int64_t footer_offset,
815               const IpcReadOptions& options) {
816     file_ = file;
817     options_ = options;
818     footer_offset_ = footer_offset;
819     RETURN_NOT_OK(ReadFooter());
820 
821     // Get the schema and record any observed dictionaries
822     return UnpackSchemaMessage(footer_->schema(), options, &dictionary_memo_, &schema_,
823                                &out_schema_, &field_inclusion_mask_);
824   }
825 
schema() const826   std::shared_ptr<Schema> schema() const override { return out_schema_; }
827 
metadata() const828   std::shared_ptr<const KeyValueMetadata> metadata() const override { return metadata_; }
829 
830  private:
GetRecordBatchBlock(int i) const831   FileBlock GetRecordBatchBlock(int i) const {
832     return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i));
833   }
834 
GetDictionaryBlock(int i) const835   FileBlock GetDictionaryBlock(int i) const {
836     return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i));
837   }
838 
ReadMessageFromBlock(const FileBlock & block,std::unique_ptr<Message> * out)839   Status ReadMessageFromBlock(const FileBlock& block, std::unique_ptr<Message>* out) {
840     if (!BitUtil::IsMultipleOf8(block.offset) ||
841         !BitUtil::IsMultipleOf8(block.metadata_length) ||
842         !BitUtil::IsMultipleOf8(block.body_length)) {
843       return Status::Invalid("Unaligned block in IPC file");
844     }
845 
846     // TODO(wesm): this breaks integration tests, see ARROW-3256
847     // DCHECK_EQ((*out)->body_length(), block.body_length);
848 
849     return ReadMessage(block.offset, block.metadata_length, file_).Value(out);
850   }
851 
ReadDictionaries()852   Status ReadDictionaries() {
853     // Read all the dictionaries
854     for (int i = 0; i < num_dictionaries(); ++i) {
855       std::unique_ptr<Message> message;
856       RETURN_NOT_OK(ReadMessageFromBlock(GetDictionaryBlock(i), &message));
857 
858       CHECK_HAS_BODY(*message);
859       ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
860       RETURN_NOT_OK(ReadDictionary(*message->metadata(), &dictionary_memo_, options_,
861                                    reader.get()));
862     }
863     return Status::OK();
864   }
865 
ReadFooter()866   Status ReadFooter() {
867     const int32_t magic_size = static_cast<int>(strlen(kArrowMagicBytes));
868 
869     if (footer_offset_ <= magic_size * 2 + 4) {
870       return Status::Invalid("File is too small: ", footer_offset_);
871     }
872 
873     int file_end_size = static_cast<int>(magic_size + sizeof(int32_t));
874     ARROW_ASSIGN_OR_RAISE(auto buffer,
875                           file_->ReadAt(footer_offset_ - file_end_size, file_end_size));
876 
877     const int64_t expected_footer_size = magic_size + sizeof(int32_t);
878     if (buffer->size() < expected_footer_size) {
879       return Status::Invalid("Unable to read ", expected_footer_size, "from end of file");
880     }
881 
882     if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) {
883       return Status::Invalid("Not an Arrow file");
884     }
885 
886     int32_t footer_length = *reinterpret_cast<const int32_t*>(buffer->data());
887 
888     if (footer_length <= 0 || footer_length > footer_offset_ - magic_size * 2 - 4) {
889       return Status::Invalid("File is smaller than indicated metadata size");
890     }
891 
892     // Now read the footer
893     ARROW_ASSIGN_OR_RAISE(
894         footer_buffer_,
895         file_->ReadAt(footer_offset_ - footer_length - file_end_size, footer_length));
896 
897     auto data = footer_buffer_->data();
898     flatbuffers::Verifier verifier(data, footer_buffer_->size(), 128);
899     if (!flatbuf::VerifyFooterBuffer(verifier)) {
900       return Status::IOError("Verification of flatbuffer-encoded Footer failed.");
901     }
902     footer_ = flatbuf::GetFooter(data);
903 
904     auto fb_metadata = footer_->custom_metadata();
905     if (fb_metadata != nullptr) {
906       std::shared_ptr<KeyValueMetadata> md;
907       RETURN_NOT_OK(internal::GetKeyValueMetadata(fb_metadata, &md));
908       metadata_ = std::move(md);  // const-ify
909     }
910 
911     return Status::OK();
912   }
913 
num_dictionaries() const914   int num_dictionaries() const {
915     return static_cast<int>(internal::FlatBuffersVectorSize(footer_->dictionaries()));
916   }
917 
918   io::RandomAccessFile* file_;
919   IpcReadOptions options_;
920   std::vector<bool> field_inclusion_mask_;
921 
922   std::shared_ptr<io::RandomAccessFile> owned_file_;
923 
924   // The location where the Arrow file layout ends. May be the end of the file
925   // or some other location if embedded in a larger file.
926   int64_t footer_offset_;
927 
928   // Footer metadata
929   std::shared_ptr<Buffer> footer_buffer_;
930   const flatbuf::Footer* footer_;
931   std::shared_ptr<const KeyValueMetadata> metadata_;
932 
933   bool read_dictionaries_ = false;
934   DictionaryMemo dictionary_memo_;
935 
936   // Reconstructed schema, including any read dictionaries
937   std::shared_ptr<Schema> schema_;
938   // Schema with deselected fields dropped
939   std::shared_ptr<Schema> out_schema_;
940 };
941 
Open(io::RandomAccessFile * file,const IpcReadOptions & options)942 Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
943     io::RandomAccessFile* file, const IpcReadOptions& options) {
944   ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
945   return Open(file, footer_offset, options);
946 }
947 
Open(io::RandomAccessFile * file,int64_t footer_offset,const IpcReadOptions & options)948 Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
949     io::RandomAccessFile* file, int64_t footer_offset, const IpcReadOptions& options) {
950   auto result = std::make_shared<RecordBatchFileReaderImpl>();
951   RETURN_NOT_OK(result->Open(file, footer_offset, options));
952   return result;
953 }
954 
Open(const std::shared_ptr<io::RandomAccessFile> & file,const IpcReadOptions & options)955 Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
956     const std::shared_ptr<io::RandomAccessFile>& file, const IpcReadOptions& options) {
957   ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
958   return Open(file, footer_offset, options);
959 }
960 
Open(const std::shared_ptr<io::RandomAccessFile> & file,int64_t footer_offset,const IpcReadOptions & options)961 Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
962     const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
963     const IpcReadOptions& options) {
964   auto result = std::make_shared<RecordBatchFileReaderImpl>();
965   RETURN_NOT_OK(result->Open(file, footer_offset, options));
966   return result;
967 }
968 
OnEOS()969 Status Listener::OnEOS() { return Status::OK(); }
970 
OnSchemaDecoded(std::shared_ptr<Schema> schema)971 Status Listener::OnSchemaDecoded(std::shared_ptr<Schema> schema) { return Status::OK(); }
972 
OnRecordBatchDecoded(std::shared_ptr<RecordBatch> record_batch)973 Status Listener::OnRecordBatchDecoded(std::shared_ptr<RecordBatch> record_batch) {
974   return Status::NotImplemented("OnRecordBatchDecoded() callback isn't implemented");
975 }
976 
// Push-based IPC stream decoder: bytes are fed in via Consume() and decoded
// messages are dispatched to the Listener. Interpretation of each message
// depends on the current protocol state (schema first, then the initial
// dictionaries, then record batches until EOS).
class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener {
 private:
  // Protocol state machine
  enum State {
    SCHEMA,
    INITIAL_DICTIONARIES,
    RECORD_BATCHES,
    EOS,
  };

 public:
  explicit StreamDecoderImpl(std::shared_ptr<Listener> listener,
                             const IpcReadOptions& options)
      : MessageDecoderListener(),
        listener_(std::move(listener)),
        options_(options),
        state_(State::SCHEMA),
        // Non-owning aliased shared_ptr (no-op deleter): lets the
        // MessageDecoder hold us as its listener without managing our lifetime
        message_decoder_(std::shared_ptr<StreamDecoderImpl>(this, [](void*) {}),
                         options_.memory_pool),
        field_inclusion_mask_(),
        n_required_dictionaries_(0),
        dictionary_memo_(),
        schema_() {}

  // Dispatch a fully-decoded message according to the current state.
  Status OnMessageDecoded(std::unique_ptr<Message> message) override {
    switch (state_) {
      case State::SCHEMA:
        ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message)));
        break;
      case State::INITIAL_DICTIONARIES:
        ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message)));
        break;
      case State::RECORD_BATCHES:
        ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message)));
        break;
      case State::EOS:
        // Ignore anything after end-of-stream
        break;
    }
    return Status::OK();
  }

  Status OnEOS() override {
    state_ = State::EOS;
    return listener_->OnEOS();
  }

  // Feed raw bytes (copied as needed by the message decoder)
  Status Consume(const uint8_t* data, int64_t size) {
    return message_decoder_.Consume(data, size);
  }

  // Feed a buffer (may be retained zero-copy by the message decoder)
  Status Consume(std::shared_ptr<Buffer> buffer) {
    return message_decoder_.Consume(std::move(buffer));
  }

  // Schema with deselected fields dropped; null until the schema message
  // (and, transitively, OnSchemaMessageDecoded) has been processed
  std::shared_ptr<Schema> schema() const { return out_schema_; }

  // Bytes needed to make progress on the next message (hint for callers)
  int64_t next_required_size() const { return message_decoder_.next_required_size(); }

 private:
  // Handle the schema message; decides whether dictionaries must follow
  // before the first record batch.
  Status OnSchemaMessageDecoded(std::unique_ptr<Message> message) {
    RETURN_NOT_OK(UnpackSchemaMessage(*message, options_, &dictionary_memo_, &schema_,
                                      &out_schema_, &field_inclusion_mask_));

    n_required_dictionaries_ = dictionary_memo_.num_fields();
    if (n_required_dictionaries_ == 0) {
      state_ = State::RECORD_BATCHES;
      // NOTE: the full schema_ (not out_schema_) is reported to the listener
      RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
    } else {
      state_ = State::INITIAL_DICTIONARIES;
    }
    return Status::OK();
  }

  // Handle one of the dictionary batches required before the first record
  // batch; the listener's OnSchemaDecoded fires once all have arrived.
  Status OnInitialDictionaryMessageDecoded(std::unique_ptr<Message> message) {
    if (message->type() != Message::DICTIONARY_BATCH) {
      return Status::Invalid("IPC stream did not have the expected number (",
                             dictionary_memo_.num_fields(),
                             ") of dictionaries at the start of the stream");
    }
    RETURN_NOT_OK(ParseDictionary(*message, &dictionary_memo_, options_));
    n_required_dictionaries_--;
    if (n_required_dictionaries_ == 0) {
      state_ = State::RECORD_BATCHES;
      ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
    }
    return Status::OK();
  }

  // Handle a record batch (or an inline dictionary delta/replacement).
  Status OnRecordBatchMessageDecoded(std::unique_ptr<Message> message) {
    if (message->type() == Message::DICTIONARY_BATCH) {
      return UpdateDictionaries(*message, &dictionary_memo_, options_);
    } else {
      CHECK_HAS_BODY(*message);
      ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
      ARROW_ASSIGN_OR_RAISE(
          auto batch,
          ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
                                  &dictionary_memo_, options_, reader.get()));
      return listener_->OnRecordBatchDecoded(std::move(batch));
    }
  }

  std::shared_ptr<Listener> listener_;
  IpcReadOptions options_;
  State state_;
  MessageDecoder message_decoder_;
  // True for fields kept in out_schema_ (see UnpackSchemaMessage)
  std::vector<bool> field_inclusion_mask_;
  // Dictionaries still expected before the first record batch
  int n_required_dictionaries_;
  DictionaryMemo dictionary_memo_;
  // schema_: as read from the stream; out_schema_: deselected fields dropped
  std::shared_ptr<Schema> schema_, out_schema_;
};
1087 
StreamDecoder(std::shared_ptr<Listener> listener,const IpcReadOptions & options)1088 StreamDecoder::StreamDecoder(std::shared_ptr<Listener> listener,
1089                              const IpcReadOptions& options) {
1090   impl_.reset(new StreamDecoderImpl(std::move(listener), options));
1091 }
1092 
~StreamDecoder()1093 StreamDecoder::~StreamDecoder() {}
1094 
Consume(const uint8_t * data,int64_t size)1095 Status StreamDecoder::Consume(const uint8_t* data, int64_t size) {
1096   return impl_->Consume(data, size);
1097 }
Consume(std::shared_ptr<Buffer> buffer)1098 Status StreamDecoder::Consume(std::shared_ptr<Buffer> buffer) {
1099   return impl_->Consume(std::move(buffer));
1100 }
1101 
schema() const1102 std::shared_ptr<Schema> StreamDecoder::schema() const { return impl_->schema(); }
1103 
next_required_size() const1104 int64_t StreamDecoder::next_required_size() const { return impl_->next_required_size(); }
1105 
ReadSchema(io::InputStream * stream,DictionaryMemo * dictionary_memo)1106 Result<std::shared_ptr<Schema>> ReadSchema(io::InputStream* stream,
1107                                            DictionaryMemo* dictionary_memo) {
1108   std::unique_ptr<MessageReader> reader = MessageReader::Open(stream);
1109   ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message, reader->ReadNextMessage());
1110   if (!message) {
1111     return Status::Invalid("Tried reading schema message, was null or length 0");
1112   }
1113   CHECK_MESSAGE_TYPE(Message::SCHEMA, message->type());
1114   return ReadSchema(*message, dictionary_memo);
1115 }
1116 
ReadSchema(const Message & message,DictionaryMemo * dictionary_memo)1117 Result<std::shared_ptr<Schema>> ReadSchema(const Message& message,
1118                                            DictionaryMemo* dictionary_memo) {
1119   std::shared_ptr<Schema> result;
1120   RETURN_NOT_OK(internal::GetSchema(message.header(), dictionary_memo, &result));
1121   return result;
1122 }
1123 
ReadTensor(io::InputStream * file)1124 Result<std::shared_ptr<Tensor>> ReadTensor(io::InputStream* file) {
1125   std::unique_ptr<Message> message;
1126   RETURN_NOT_OK(ReadContiguousPayload(file, &message));
1127   return ReadTensor(*message);
1128 }
1129 
ReadTensor(const Message & message)1130 Result<std::shared_ptr<Tensor>> ReadTensor(const Message& message) {
1131   std::shared_ptr<DataType> type;
1132   std::vector<int64_t> shape;
1133   std::vector<int64_t> strides;
1134   std::vector<std::string> dim_names;
1135   CHECK_HAS_BODY(message);
1136   RETURN_NOT_OK(internal::GetTensorMetadata(*message.metadata(), &type, &shape, &strides,
1137                                             &dim_names));
1138   return Tensor::Make(type, message.body(), shape, strides, dim_names);
1139 }
1140 
1141 namespace {
1142 
ReadSparseCOOIndex(const flatbuf::SparseTensor * sparse_tensor,const std::vector<int64_t> & shape,int64_t non_zero_length,io::RandomAccessFile * file)1143 Result<std::shared_ptr<SparseIndex>> ReadSparseCOOIndex(
1144     const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
1145     int64_t non_zero_length, io::RandomAccessFile* file) {
1146   auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCOO();
1147   const auto ndim = static_cast<int64_t>(shape.size());
1148 
1149   std::shared_ptr<DataType> indices_type;
1150   RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(sparse_index, &indices_type));
1151   const int64_t indices_elsize =
1152       checked_cast<const IntegerType&>(*indices_type).bit_width() / 8;
1153 
1154   auto* indices_buffer = sparse_index->indicesBuffer();
1155   ARROW_ASSIGN_OR_RAISE(auto indices_data,
1156                         file->ReadAt(indices_buffer->offset(), indices_buffer->length()));
1157   std::vector<int64_t> indices_shape({non_zero_length, ndim});
1158   auto* indices_strides = sparse_index->indicesStrides();
1159   std::vector<int64_t> strides(2);
1160   if (indices_strides && indices_strides->size() > 0) {
1161     if (indices_strides->size() != 2) {
1162       return Status::Invalid("Wrong size for indicesStrides in SparseCOOIndex");
1163     }
1164     strides[0] = indices_strides->Get(0);
1165     strides[1] = indices_strides->Get(1);
1166   } else {
1167     // Row-major by default
1168     strides[0] = indices_elsize * ndim;
1169     strides[1] = indices_elsize;
1170   }
1171   return std::make_shared<SparseCOOIndex>(
1172       std::make_shared<Tensor>(indices_type, indices_data, indices_shape, strides));
1173 }
1174 
ReadSparseCSXIndex(const flatbuf::SparseTensor * sparse_tensor,const std::vector<int64_t> & shape,int64_t non_zero_length,io::RandomAccessFile * file)1175 Result<std::shared_ptr<SparseIndex>> ReadSparseCSXIndex(
1176     const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
1177     int64_t non_zero_length, io::RandomAccessFile* file) {
1178   if (shape.size() != 2) {
1179     return Status::Invalid("Invalid shape length for a sparse matrix");
1180   }
1181 
1182   auto* sparse_index = sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX();
1183 
1184   std::shared_ptr<DataType> indptr_type, indices_type;
1185   RETURN_NOT_OK(
1186       internal::GetSparseCSXIndexMetadata(sparse_index, &indptr_type, &indices_type));
1187 
1188   auto* indptr_buffer = sparse_index->indptrBuffer();
1189   ARROW_ASSIGN_OR_RAISE(auto indptr_data,
1190                         file->ReadAt(indptr_buffer->offset(), indptr_buffer->length()));
1191 
1192   auto* indices_buffer = sparse_index->indicesBuffer();
1193   ARROW_ASSIGN_OR_RAISE(auto indices_data,
1194                         file->ReadAt(indices_buffer->offset(), indices_buffer->length()));
1195 
1196   std::vector<int64_t> indices_shape({non_zero_length});
1197   const auto indices_minimum_bytes =
1198       indices_shape[0] * checked_pointer_cast<FixedWidthType>(indices_type)->bit_width() /
1199       CHAR_BIT;
1200   if (indices_minimum_bytes > indices_buffer->length()) {
1201     return Status::Invalid("shape is inconsistent to the size of indices buffer");
1202   }
1203 
1204   switch (sparse_index->compressedAxis()) {
1205     case flatbuf::SparseMatrixCompressedAxis::Row: {
1206       std::vector<int64_t> indptr_shape({shape[0] + 1});
1207       const int64_t indptr_minimum_bytes =
1208           indptr_shape[0] *
1209           checked_pointer_cast<FixedWidthType>(indptr_type)->bit_width() / CHAR_BIT;
1210       if (indptr_minimum_bytes > indptr_buffer->length()) {
1211         return Status::Invalid("shape is inconsistent to the size of indptr buffer");
1212       }
1213       return std::make_shared<SparseCSRIndex>(
1214           std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
1215           std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
1216     }
1217     case flatbuf::SparseMatrixCompressedAxis::Column: {
1218       std::vector<int64_t> indptr_shape({shape[1] + 1});
1219       const int64_t indptr_minimum_bytes =
1220           indptr_shape[0] *
1221           checked_pointer_cast<FixedWidthType>(indptr_type)->bit_width() / CHAR_BIT;
1222       if (indptr_minimum_bytes > indptr_buffer->length()) {
1223         return Status::Invalid("shape is inconsistent to the size of indptr buffer");
1224       }
1225       return std::make_shared<SparseCSCIndex>(
1226           std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
1227           std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
1228     }
1229     default:
1230       return Status::Invalid("Invalid value of SparseMatrixCompressedAxis");
1231   }
1232 }
1233 
ReadSparseCSFIndex(const flatbuf::SparseTensor * sparse_tensor,const std::vector<int64_t> & shape,io::RandomAccessFile * file)1234 Result<std::shared_ptr<SparseIndex>> ReadSparseCSFIndex(
1235     const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
1236     io::RandomAccessFile* file) {
1237   auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCSF();
1238   const auto ndim = static_cast<int64_t>(shape.size());
1239   auto* indptr_buffers = sparse_index->indptrBuffers();
1240   auto* indices_buffers = sparse_index->indicesBuffers();
1241   std::vector<std::shared_ptr<Buffer>> indptr_data(ndim - 1);
1242   std::vector<std::shared_ptr<Buffer>> indices_data(ndim);
1243 
1244   std::shared_ptr<DataType> indptr_type, indices_type;
1245   std::vector<int64_t> axis_order, indices_size;
1246 
1247   RETURN_NOT_OK(internal::GetSparseCSFIndexMetadata(
1248       sparse_index, &axis_order, &indices_size, &indptr_type, &indices_type));
1249   for (int i = 0; i < static_cast<int>(indptr_buffers->Length()); ++i) {
1250     ARROW_ASSIGN_OR_RAISE(indptr_data[i], file->ReadAt(indptr_buffers->Get(i)->offset(),
1251                                                        indptr_buffers->Get(i)->length()));
1252   }
1253   for (int i = 0; i < static_cast<int>(indices_buffers->Length()); ++i) {
1254     ARROW_ASSIGN_OR_RAISE(indices_data[i],
1255                           file->ReadAt(indices_buffers->Get(i)->offset(),
1256                                        indices_buffers->Get(i)->length()));
1257   }
1258 
1259   return SparseCSFIndex::Make(indptr_type, indices_type, indices_size, axis_order,
1260                               indptr_data, indices_data);
1261 }
1262 
MakeSparseTensorWithSparseCOOIndex(const std::shared_ptr<DataType> & type,const std::vector<int64_t> & shape,const std::vector<std::string> & dim_names,const std::shared_ptr<SparseCOOIndex> & sparse_index,int64_t non_zero_length,const std::shared_ptr<Buffer> & data)1263 Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCOOIndex(
1264     const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
1265     const std::vector<std::string>& dim_names,
1266     const std::shared_ptr<SparseCOOIndex>& sparse_index, int64_t non_zero_length,
1267     const std::shared_ptr<Buffer>& data) {
1268   return SparseCOOTensor::Make(sparse_index, type, data, shape, dim_names);
1269 }
1270 
MakeSparseTensorWithSparseCSRIndex(const std::shared_ptr<DataType> & type,const std::vector<int64_t> & shape,const std::vector<std::string> & dim_names,const std::shared_ptr<SparseCSRIndex> & sparse_index,int64_t non_zero_length,const std::shared_ptr<Buffer> & data)1271 Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSRIndex(
1272     const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
1273     const std::vector<std::string>& dim_names,
1274     const std::shared_ptr<SparseCSRIndex>& sparse_index, int64_t non_zero_length,
1275     const std::shared_ptr<Buffer>& data) {
1276   return SparseCSRMatrix::Make(sparse_index, type, data, shape, dim_names);
1277 }
1278 
MakeSparseTensorWithSparseCSCIndex(const std::shared_ptr<DataType> & type,const std::vector<int64_t> & shape,const std::vector<std::string> & dim_names,const std::shared_ptr<SparseCSCIndex> & sparse_index,int64_t non_zero_length,const std::shared_ptr<Buffer> & data)1279 Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSCIndex(
1280     const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
1281     const std::vector<std::string>& dim_names,
1282     const std::shared_ptr<SparseCSCIndex>& sparse_index, int64_t non_zero_length,
1283     const std::shared_ptr<Buffer>& data) {
1284   return SparseCSCMatrix::Make(sparse_index, type, data, shape, dim_names);
1285 }
1286 
MakeSparseTensorWithSparseCSFIndex(const std::shared_ptr<DataType> & type,const std::vector<int64_t> & shape,const std::vector<std::string> & dim_names,const std::shared_ptr<SparseCSFIndex> & sparse_index,const std::shared_ptr<Buffer> & data)1287 Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSFIndex(
1288     const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
1289     const std::vector<std::string>& dim_names,
1290     const std::shared_ptr<SparseCSFIndex>& sparse_index,
1291     const std::shared_ptr<Buffer>& data) {
1292   return SparseCSFTensor::Make(sparse_index, type, data, shape, dim_names);
1293 }
1294 
ReadSparseTensorMetadata(const Buffer & metadata,std::shared_ptr<DataType> * out_type,std::vector<int64_t> * out_shape,std::vector<std::string> * out_dim_names,int64_t * out_non_zero_length,SparseTensorFormat::type * out_format_id,const flatbuf::SparseTensor ** out_fb_sparse_tensor,const flatbuf::Buffer ** out_buffer)1295 Status ReadSparseTensorMetadata(const Buffer& metadata,
1296                                 std::shared_ptr<DataType>* out_type,
1297                                 std::vector<int64_t>* out_shape,
1298                                 std::vector<std::string>* out_dim_names,
1299                                 int64_t* out_non_zero_length,
1300                                 SparseTensorFormat::type* out_format_id,
1301                                 const flatbuf::SparseTensor** out_fb_sparse_tensor,
1302                                 const flatbuf::Buffer** out_buffer) {
1303   RETURN_NOT_OK(internal::GetSparseTensorMetadata(
1304       metadata, out_type, out_shape, out_dim_names, out_non_zero_length, out_format_id));
1305 
1306   const flatbuf::Message* message;
1307   RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
1308 
1309   auto sparse_tensor = message->header_as_SparseTensor();
1310   if (sparse_tensor == nullptr) {
1311     return Status::IOError(
1312         "Header-type of flatbuffer-encoded Message is not SparseTensor.");
1313   }
1314   *out_fb_sparse_tensor = sparse_tensor;
1315 
1316   auto buffer = sparse_tensor->data();
1317   if (!BitUtil::IsMultipleOf8(buffer->offset())) {
1318     return Status::Invalid(
1319         "Buffer of sparse index data did not start on 8-byte aligned offset: ",
1320         buffer->offset());
1321   }
1322   *out_buffer = buffer;
1323 
1324   return Status::OK();
1325 }
1326 
1327 }  // namespace
1328 
1329 namespace internal {
1330 
1331 namespace {
1332 
GetSparseTensorBodyBufferCount(SparseTensorFormat::type format_id,const size_t ndim)1333 Result<size_t> GetSparseTensorBodyBufferCount(SparseTensorFormat::type format_id,
1334                                               const size_t ndim) {
1335   switch (format_id) {
1336     case SparseTensorFormat::COO:
1337       return 2;
1338 
1339     case SparseTensorFormat::CSR:
1340       return 3;
1341 
1342     case SparseTensorFormat::CSC:
1343       return 3;
1344 
1345     case SparseTensorFormat::CSF:
1346       return 2 * ndim;
1347 
1348     default:
1349       return Status::Invalid("Unrecognized sparse tensor format");
1350   }
1351 }
1352 
CheckSparseTensorBodyBufferCount(const IpcPayload & payload,SparseTensorFormat::type sparse_tensor_format_id,const size_t ndim)1353 Status CheckSparseTensorBodyBufferCount(const IpcPayload& payload,
1354                                         SparseTensorFormat::type sparse_tensor_format_id,
1355                                         const size_t ndim) {
1356   size_t expected_body_buffer_count = 0;
1357   ARROW_ASSIGN_OR_RAISE(expected_body_buffer_count,
1358                         GetSparseTensorBodyBufferCount(sparse_tensor_format_id, ndim));
1359   if (payload.body_buffers.size() != expected_body_buffer_count) {
1360     return Status::Invalid("Invalid body buffer count for a sparse tensor");
1361   }
1362 
1363   return Status::OK();
1364 }
1365 
1366 }  // namespace
1367 
ReadSparseTensorBodyBufferCount(const Buffer & metadata)1368 Result<size_t> ReadSparseTensorBodyBufferCount(const Buffer& metadata) {
1369   SparseTensorFormat::type format_id;
1370   std::vector<int64_t> shape;
1371 
1372   RETURN_NOT_OK(internal::GetSparseTensorMetadata(metadata, nullptr, &shape, nullptr,
1373                                                   nullptr, &format_id));
1374 
1375   return GetSparseTensorBodyBufferCount(format_id, static_cast<size_t>(shape.size()));
1376 }
1377 
// Reconstruct a SparseTensor from an IPC payload. The payload's body buffers
// must match the per-format layout checked by CheckSparseTensorBodyBufferCount:
//   COO:     [0]=indices, [1]=data
//   CSR/CSC: [0]=indptr, [1]=indices, [2]=data
//   CSF:     [0 .. ndim-2]=indptr, [ndim-1 .. 2*ndim-2]=indices, [2*ndim-1]=data
Result<std::shared_ptr<SparseTensor>> ReadSparseTensorPayload(const IpcPayload& payload) {
  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<std::string> dim_names;
  int64_t non_zero_length;
  SparseTensorFormat::type sparse_tensor_format_id;
  const flatbuf::SparseTensor* sparse_tensor;
  const flatbuf::Buffer* buffer;

  RETURN_NOT_OK(ReadSparseTensorMetadata(*payload.metadata, &type, &shape, &dim_names,
                                         &non_zero_length, &sparse_tensor_format_id,
                                         &sparse_tensor, &buffer));

  // Reject payloads whose buffer count does not match the format's layout
  // before indexing into body_buffers below.
  RETURN_NOT_OK(CheckSparseTensorBodyBufferCount(payload, sparse_tensor_format_id,
                                                 static_cast<size_t>(shape.size())));

  switch (sparse_tensor_format_id) {
    case SparseTensorFormat::COO: {
      std::shared_ptr<SparseCOOIndex> sparse_index;
      std::shared_ptr<DataType> indices_type;
      RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseTensorIndexCOO(), &indices_type));
      // body_buffers[0] holds the coordinate indices, [1] the values.
      ARROW_ASSIGN_OR_RAISE(sparse_index,
                            SparseCOOIndex::Make(indices_type, shape, non_zero_length,
                                                 payload.body_buffers[0]));
      return MakeSparseTensorWithSparseCOOIndex(type, shape, dim_names, sparse_index,
                                                non_zero_length, payload.body_buffers[1]);
    }
    case SparseTensorFormat::CSR: {
      std::shared_ptr<SparseCSRIndex> sparse_index;
      std::shared_ptr<DataType> indptr_type;
      std::shared_ptr<DataType> indices_type;
      RETURN_NOT_OK(internal::GetSparseCSXIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX(), &indptr_type,
          &indices_type));
      // NOTE(review): compares shared_ptr identity, not type equality —
      // presumably relies on primitive DataType instances being singletons;
      // TODO confirm.
      ARROW_CHECK_EQ(indptr_type, indices_type);
      // body_buffers[0]=indptr, [1]=indices, [2]=values.
      ARROW_ASSIGN_OR_RAISE(
          sparse_index,
          SparseCSRIndex::Make(indices_type, shape, non_zero_length,
                               payload.body_buffers[0], payload.body_buffers[1]));
      return MakeSparseTensorWithSparseCSRIndex(type, shape, dim_names, sparse_index,
                                                non_zero_length, payload.body_buffers[2]);
    }
    case SparseTensorFormat::CSC: {
      std::shared_ptr<SparseCSCIndex> sparse_index;
      std::shared_ptr<DataType> indptr_type;
      std::shared_ptr<DataType> indices_type;
      RETURN_NOT_OK(internal::GetSparseCSXIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX(), &indptr_type,
          &indices_type));
      // NOTE(review): shared_ptr identity comparison, as in the CSR case.
      ARROW_CHECK_EQ(indptr_type, indices_type);
      // body_buffers[0]=indptr, [1]=indices, [2]=values.
      ARROW_ASSIGN_OR_RAISE(
          sparse_index,
          SparseCSCIndex::Make(indices_type, shape, non_zero_length,
                               payload.body_buffers[0], payload.body_buffers[1]));
      return MakeSparseTensorWithSparseCSCIndex(type, shape, dim_names, sparse_index,
                                                non_zero_length, payload.body_buffers[2]);
    }
    case SparseTensorFormat::CSF: {
      std::shared_ptr<SparseCSFIndex> sparse_index;
      std::shared_ptr<DataType> indptr_type, indices_type;
      std::vector<int64_t> axis_order, indices_size;

      RETURN_NOT_OK(internal::GetSparseCSFIndexMetadata(
          sparse_tensor->sparseIndex_as_SparseTensorIndexCSF(), &axis_order,
          &indices_size, &indptr_type, &indices_type));
      // NOTE(review): shared_ptr identity comparison, as in the CSR case.
      ARROW_CHECK_EQ(indptr_type, indices_type);

      const int64_t ndim = shape.size();
      // CSF stores one indptr buffer per inner dimension boundary (ndim - 1)
      // and one indices buffer per dimension (ndim); values buffer is last.
      std::vector<std::shared_ptr<Buffer>> indptr_data(ndim - 1);
      std::vector<std::shared_ptr<Buffer>> indices_data(ndim);

      for (int64_t i = 0; i < ndim - 1; ++i) {
        indptr_data[i] = payload.body_buffers[i];
      }
      for (int64_t i = 0; i < ndim; ++i) {
        indices_data[i] = payload.body_buffers[i + ndim - 1];
      }

      ARROW_ASSIGN_OR_RAISE(sparse_index,
                            SparseCSFIndex::Make(indptr_type, indices_type, indices_size,
                                                 axis_order, indptr_data, indices_data));
      return MakeSparseTensorWithSparseCSFIndex(type, shape, dim_names, sparse_index,
                                                payload.body_buffers[2 * ndim - 1]);
    }
    default:
      return Status::Invalid("Unsupported sparse index format");
  }
}
1467 
1468 }  // namespace internal
1469 
// Read a SparseTensor whose flatbuffer metadata is in `metadata` and whose
// body (sparse index buffers + values) is read from `file` at the offsets
// recorded in the metadata. Dispatches on the sparse index format.
Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Buffer& metadata,
                                                       io::RandomAccessFile* file) {
  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<std::string> dim_names;
  int64_t non_zero_length;
  SparseTensorFormat::type sparse_tensor_format_id;
  const flatbuf::SparseTensor* sparse_tensor;
  const flatbuf::Buffer* buffer;

  RETURN_NOT_OK(ReadSparseTensorMetadata(metadata, &type, &shape, &dim_names,
                                         &non_zero_length, &sparse_tensor_format_id,
                                         &sparse_tensor, &buffer));

  // `buffer` locates the values data within `file` (offset was validated to be
  // 8-byte aligned by ReadSparseTensorMetadata).
  ARROW_ASSIGN_OR_RAISE(auto data, file->ReadAt(buffer->offset(), buffer->length()));

  std::shared_ptr<SparseIndex> sparse_index;
  switch (sparse_tensor_format_id) {
    case SparseTensorFormat::COO: {
      ARROW_ASSIGN_OR_RAISE(
          sparse_index, ReadSparseCOOIndex(sparse_tensor, shape, non_zero_length, file));
      return MakeSparseTensorWithSparseCOOIndex(
          type, shape, dim_names, checked_pointer_cast<SparseCOOIndex>(sparse_index),
          non_zero_length, data);
    }
    case SparseTensorFormat::CSR: {
      // CSR and CSC share the CSX index reader; only the final tensor type
      // differs.
      ARROW_ASSIGN_OR_RAISE(
          sparse_index, ReadSparseCSXIndex(sparse_tensor, shape, non_zero_length, file));
      return MakeSparseTensorWithSparseCSRIndex(
          type, shape, dim_names, checked_pointer_cast<SparseCSRIndex>(sparse_index),
          non_zero_length, data);
    }
    case SparseTensorFormat::CSC: {
      ARROW_ASSIGN_OR_RAISE(
          sparse_index, ReadSparseCSXIndex(sparse_tensor, shape, non_zero_length, file));
      return MakeSparseTensorWithSparseCSCIndex(
          type, shape, dim_names, checked_pointer_cast<SparseCSCIndex>(sparse_index),
          non_zero_length, data);
    }
    case SparseTensorFormat::CSF: {
      ARROW_ASSIGN_OR_RAISE(sparse_index, ReadSparseCSFIndex(sparse_tensor, shape, file));
      return MakeSparseTensorWithSparseCSFIndex(
          type, shape, dim_names, checked_pointer_cast<SparseCSFIndex>(sparse_index),
          data);
    }
    default:
      return Status::Invalid("Unsupported sparse index format");
  }
}
1519 
ReadSparseTensor(const Message & message)1520 Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Message& message) {
1521   CHECK_HAS_BODY(message);
1522   ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
1523   return ReadSparseTensor(*message.metadata(), reader.get());
1524 }
1525 
ReadSparseTensor(io::InputStream * file)1526 Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(io::InputStream* file) {
1527   std::unique_ptr<Message> message;
1528   RETURN_NOT_OK(ReadContiguousPayload(file, &message));
1529   CHECK_MESSAGE_TYPE(Message::SPARSE_TENSOR, message->type());
1530   CHECK_HAS_BODY(*message);
1531   ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
1532   return ReadSparseTensor(*message->metadata(), reader.get());
1533 }
1534 
1535 ///////////////////////////////////////////////////////////////////////////
1536 // Helpers for fuzzing
1537 
1538 namespace internal {
1539 
FuzzIpcStream(const uint8_t * data,int64_t size)1540 Status FuzzIpcStream(const uint8_t* data, int64_t size) {
1541   auto buffer = std::make_shared<Buffer>(data, size);
1542   io::BufferReader buffer_reader(buffer);
1543 
1544   std::shared_ptr<RecordBatchReader> batch_reader;
1545   ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchStreamReader::Open(&buffer_reader));
1546 
1547   while (true) {
1548     std::shared_ptr<arrow::RecordBatch> batch;
1549     RETURN_NOT_OK(batch_reader->ReadNext(&batch));
1550     if (batch == nullptr) {
1551       break;
1552     }
1553     RETURN_NOT_OK(batch->ValidateFull());
1554   }
1555 
1556   return Status::OK();
1557 }
1558 
FuzzIpcFile(const uint8_t * data,int64_t size)1559 Status FuzzIpcFile(const uint8_t* data, int64_t size) {
1560   auto buffer = std::make_shared<Buffer>(data, size);
1561   io::BufferReader buffer_reader(buffer);
1562 
1563   std::shared_ptr<RecordBatchFileReader> batch_reader;
1564   ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchFileReader::Open(&buffer_reader));
1565 
1566   const int n_batches = batch_reader->num_record_batches();
1567   for (int i = 0; i < n_batches; ++i) {
1568     ARROW_ASSIGN_OR_RAISE(auto batch, batch_reader->ReadRecordBatch(i));
1569     RETURN_NOT_OK(batch->ValidateFull());
1570   }
1571 
1572   return Status::OK();
1573 }
1574 
1575 }  // namespace internal
1576 
1577 // ----------------------------------------------------------------------
1578 // Deprecated functions
1579 
Open(std::unique_ptr<MessageReader> message_reader,std::shared_ptr<RecordBatchReader> * out)1580 Status RecordBatchStreamReader::Open(std::unique_ptr<MessageReader> message_reader,
1581                                      std::shared_ptr<RecordBatchReader>* out) {
1582   return Open(std::move(message_reader), IpcReadOptions::Defaults()).Value(out);
1583 }
1584 
Open(std::unique_ptr<MessageReader> message_reader,std::unique_ptr<RecordBatchReader> * out)1585 Status RecordBatchStreamReader::Open(std::unique_ptr<MessageReader> message_reader,
1586                                      std::unique_ptr<RecordBatchReader>* out) {
1587   auto result =
1588       std::unique_ptr<RecordBatchStreamReaderImpl>(new RecordBatchStreamReaderImpl());
1589   RETURN_NOT_OK(result->Open(std::move(message_reader), IpcReadOptions::Defaults()));
1590   *out = std::move(result);
1591   return Status::OK();
1592 }
1593 
Open(io::InputStream * stream,std::shared_ptr<RecordBatchReader> * out)1594 Status RecordBatchStreamReader::Open(io::InputStream* stream,
1595                                      std::shared_ptr<RecordBatchReader>* out) {
1596   return Open(MessageReader::Open(stream)).Value(out);
1597 }
1598 
Open(const std::shared_ptr<io::InputStream> & stream,std::shared_ptr<RecordBatchReader> * out)1599 Status RecordBatchStreamReader::Open(const std::shared_ptr<io::InputStream>& stream,
1600                                      std::shared_ptr<RecordBatchReader>* out) {
1601   return Open(MessageReader::Open(stream)).Value(out);
1602 }
1603 
Open(io::RandomAccessFile * file,std::shared_ptr<RecordBatchFileReader> * out)1604 Status RecordBatchFileReader::Open(io::RandomAccessFile* file,
1605                                    std::shared_ptr<RecordBatchFileReader>* out) {
1606   return Open(file).Value(out);
1607 }
1608 
Open(io::RandomAccessFile * file,int64_t footer_offset,std::shared_ptr<RecordBatchFileReader> * out)1609 Status RecordBatchFileReader::Open(io::RandomAccessFile* file, int64_t footer_offset,
1610                                    std::shared_ptr<RecordBatchFileReader>* out) {
1611   return Open(file, footer_offset).Value(out);
1612 }
1613 
Open(const std::shared_ptr<io::RandomAccessFile> & file,std::shared_ptr<RecordBatchFileReader> * out)1614 Status RecordBatchFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file,
1615                                    std::shared_ptr<RecordBatchFileReader>* out) {
1616   return Open(file).Value(out);
1617 }
1618 
Open(const std::shared_ptr<io::RandomAccessFile> & file,int64_t footer_offset,std::shared_ptr<RecordBatchFileReader> * out)1619 Status RecordBatchFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file,
1620                                    int64_t footer_offset,
1621                                    std::shared_ptr<RecordBatchFileReader>* out) {
1622   return Open(file, footer_offset).Value(out);
1623 }
1624 
ReadSchema(io::InputStream * stream,DictionaryMemo * dictionary_memo,std::shared_ptr<Schema> * out)1625 Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo,
1626                   std::shared_ptr<Schema>* out) {
1627   return ReadSchema(stream, dictionary_memo).Value(out);
1628 }
1629 
ReadSchema(const Message & message,DictionaryMemo * dictionary_memo,std::shared_ptr<Schema> * out)1630 Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo,
1631                   std::shared_ptr<Schema>* out) {
1632   return ReadSchema(message, dictionary_memo).Value(out);
1633 }
1634 
ReadRecordBatch(const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,io::InputStream * stream,std::shared_ptr<RecordBatch> * out)1635 Status ReadRecordBatch(const std::shared_ptr<Schema>& schema,
1636                        const DictionaryMemo* dictionary_memo, io::InputStream* stream,
1637                        std::shared_ptr<RecordBatch>* out) {
1638   return ReadRecordBatch(schema, dictionary_memo, IpcReadOptions::Defaults(), stream)
1639       .Value(out);
1640 }
1641 
ReadRecordBatch(const Message & message,const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,std::shared_ptr<RecordBatch> * out)1642 Status ReadRecordBatch(const Message& message, const std::shared_ptr<Schema>& schema,
1643                        const DictionaryMemo* dictionary_memo,
1644                        std::shared_ptr<RecordBatch>* out) {
1645   return ReadRecordBatch(message, schema, dictionary_memo, IpcReadOptions::Defaults())
1646       .Value(out);
1647 }
1648 
ReadRecordBatch(const Buffer & metadata,const std::shared_ptr<Schema> & schema,const DictionaryMemo * dictionary_memo,io::RandomAccessFile * file,std::shared_ptr<RecordBatch> * out)1649 Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr<Schema>& schema,
1650                        const DictionaryMemo* dictionary_memo, io::RandomAccessFile* file,
1651                        std::shared_ptr<RecordBatch>* out) {
1652   return ReadRecordBatch(metadata, schema, dictionary_memo, IpcReadOptions::Defaults(),
1653                          file)
1654       .Value(out);
1655 }
1656 
1657 }  // namespace ipc
1658 }  // namespace arrow
1659