1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "arrow/ipc/message.h"
19 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
27 
28 #include "arrow/buffer.h"
29 #include "arrow/device.h"
30 #include "arrow/io/interfaces.h"
31 #include "arrow/ipc/metadata_internal.h"
32 #include "arrow/ipc/options.h"
33 #include "arrow/ipc/util.h"
34 #include "arrow/status.h"
35 #include "arrow/util/endian.h"
36 #include "arrow/util/future.h"
37 #include "arrow/util/logging.h"
38 #include "arrow/util/ubsan.h"
39 
40 #include "generated/Message_generated.h"
41 
42 namespace arrow {
43 
44 class KeyValueMetadata;
45 class MemoryPool;
46 
47 namespace ipc {
48 
49 class Message::MessageImpl {
50  public:
MessageImpl(std::shared_ptr<Buffer> metadata,std::shared_ptr<Buffer> body)51   explicit MessageImpl(std::shared_ptr<Buffer> metadata, std::shared_ptr<Buffer> body)
52       : metadata_(std::move(metadata)), message_(nullptr), body_(std::move(body)) {}
53 
Open()54   Status Open() {
55     RETURN_NOT_OK(
56         internal::VerifyMessage(metadata_->data(), metadata_->size(), &message_));
57 
58     // Check that the metadata version is supported
59     if (message_->version() < internal::kMinMetadataVersion) {
60       return Status::Invalid("Old metadata version not supported");
61     }
62 
63     if (message_->version() > flatbuf::MetadataVersion::MAX) {
64       return Status::Invalid("Unsupported future MetadataVersion: ",
65                              static_cast<int16_t>(message_->version()));
66     }
67 
68     if (message_->custom_metadata() != nullptr) {
69       // Deserialize from Flatbuffers if first time called
70       std::shared_ptr<KeyValueMetadata> md;
71       RETURN_NOT_OK(internal::GetKeyValueMetadata(message_->custom_metadata(), &md));
72       custom_metadata_ = std::move(md);  // const-ify
73     }
74 
75     return Status::OK();
76   }
77 
type() const78   MessageType type() const {
79     switch (message_->header_type()) {
80       case flatbuf::MessageHeader::Schema:
81         return MessageType::SCHEMA;
82       case flatbuf::MessageHeader::DictionaryBatch:
83         return MessageType::DICTIONARY_BATCH;
84       case flatbuf::MessageHeader::RecordBatch:
85         return MessageType::RECORD_BATCH;
86       case flatbuf::MessageHeader::Tensor:
87         return MessageType::TENSOR;
88       case flatbuf::MessageHeader::SparseTensor:
89         return MessageType::SPARSE_TENSOR;
90       default:
91         return MessageType::NONE;
92     }
93   }
94 
version() const95   MetadataVersion version() const {
96     return internal::GetMetadataVersion(message_->version());
97   }
98 
header() const99   const void* header() const { return message_->header(); }
100 
body_length() const101   int64_t body_length() const { return message_->bodyLength(); }
102 
body() const103   std::shared_ptr<Buffer> body() const { return body_; }
104 
metadata() const105   std::shared_ptr<Buffer> metadata() const { return metadata_; }
106 
custom_metadata() const107   const std::shared_ptr<const KeyValueMetadata>& custom_metadata() const {
108     return custom_metadata_;
109   }
110 
111  private:
112   // The Flatbuffer metadata
113   std::shared_ptr<Buffer> metadata_;
114   const flatbuf::Message* message_;
115 
116   // The reconstructed custom_metadata field from the Message Flatbuffer
117   std::shared_ptr<const KeyValueMetadata> custom_metadata_;
118 
119   // The message body, if any
120   std::shared_ptr<Buffer> body_;
121 };
122 
Message(std::shared_ptr<Buffer> metadata,std::shared_ptr<Buffer> body)123 Message::Message(std::shared_ptr<Buffer> metadata, std::shared_ptr<Buffer> body) {
124   impl_.reset(new MessageImpl(std::move(metadata), std::move(body)));
125 }
126 
Open(std::shared_ptr<Buffer> metadata,std::shared_ptr<Buffer> body)127 Result<std::unique_ptr<Message>> Message::Open(std::shared_ptr<Buffer> metadata,
128                                                std::shared_ptr<Buffer> body) {
129   std::unique_ptr<Message> result(new Message(std::move(metadata), std::move(body)));
130   RETURN_NOT_OK(result->impl_->Open());
131   return std::move(result);
132 }
133 
~Message()134 Message::~Message() {}
135 
body() const136 std::shared_ptr<Buffer> Message::body() const { return impl_->body(); }
137 
body_length() const138 int64_t Message::body_length() const { return impl_->body_length(); }
139 
metadata() const140 std::shared_ptr<Buffer> Message::metadata() const { return impl_->metadata(); }
141 
type() const142 MessageType Message::type() const { return impl_->type(); }
143 
metadata_version() const144 MetadataVersion Message::metadata_version() const { return impl_->version(); }
145 
header() const146 const void* Message::header() const { return impl_->header(); }
147 
custom_metadata() const148 const std::shared_ptr<const KeyValueMetadata>& Message::custom_metadata() const {
149   return impl_->custom_metadata();
150 }
151 
Equals(const Message & other) const152 bool Message::Equals(const Message& other) const {
153   int64_t metadata_bytes = std::min(metadata()->size(), other.metadata()->size());
154 
155   if (!metadata()->Equals(*other.metadata(), metadata_bytes)) {
156     return false;
157   }
158 
159   // Compare bodies, if they have them
160   auto this_body = body();
161   auto other_body = other.body();
162 
163   const bool this_has_body = (this_body != nullptr) && (this_body->size() > 0);
164   const bool other_has_body = (other_body != nullptr) && (other_body->size() > 0);
165 
166   if (this_has_body && other_has_body) {
167     return this_body->Equals(*other_body);
168   } else if (this_has_body ^ other_has_body) {
169     // One has a body but not the other
170     return false;
171   } else {
172     // Neither has a body
173     return true;
174   }
175 }
176 
MaybeAlignMetadata(std::shared_ptr<Buffer> * metadata)177 Status MaybeAlignMetadata(std::shared_ptr<Buffer>* metadata) {
178   if (reinterpret_cast<uintptr_t>((*metadata)->data()) % 8 != 0) {
179     // If the metadata memory is not aligned, we copy it here to avoid
180     // potential UBSAN issues from Flatbuffers
181     ARROW_ASSIGN_OR_RAISE(*metadata, (*metadata)->CopySlice(0, (*metadata)->size()));
182   }
183   return Status::OK();
184 }
185 
CheckMetadataAndGetBodyLength(const Buffer & metadata,int64_t * body_length)186 Status CheckMetadataAndGetBodyLength(const Buffer& metadata, int64_t* body_length) {
187   const flatbuf::Message* fb_message = nullptr;
188   RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &fb_message));
189   *body_length = fb_message->bodyLength();
190   if (*body_length < 0) {
191     return Status::IOError("Invalid IPC message: negative bodyLength");
192   }
193   return Status::OK();
194 }
195 
ReadFrom(std::shared_ptr<Buffer> metadata,io::InputStream * stream)196 Result<std::unique_ptr<Message>> Message::ReadFrom(std::shared_ptr<Buffer> metadata,
197                                                    io::InputStream* stream) {
198   std::unique_ptr<Message> result;
199   auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
200   MessageDecoder decoder(listener, MessageDecoder::State::METADATA, metadata->size());
201   ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
202 
203   ARROW_ASSIGN_OR_RAISE(auto body, stream->Read(decoder.next_required_size()));
204   if (body->size() < decoder.next_required_size()) {
205     return Status::IOError("Expected to be able to read ", decoder.next_required_size(),
206                            " bytes for message body, got ", body->size());
207   }
208   RETURN_NOT_OK(decoder.Consume(body));
209   return std::move(result);
210 }
211 
ReadFrom(const int64_t offset,std::shared_ptr<Buffer> metadata,io::RandomAccessFile * file)212 Result<std::unique_ptr<Message>> Message::ReadFrom(const int64_t offset,
213                                                    std::shared_ptr<Buffer> metadata,
214                                                    io::RandomAccessFile* file) {
215   std::unique_ptr<Message> result;
216   auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
217   MessageDecoder decoder(listener, MessageDecoder::State::METADATA, metadata->size());
218   ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
219 
220   ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset, decoder.next_required_size()));
221   if (body->size() < decoder.next_required_size()) {
222     return Status::IOError("Expected to be able to read ", decoder.next_required_size(),
223                            " bytes for message body, got ", body->size());
224   }
225   RETURN_NOT_OK(decoder.Consume(body));
226   return std::move(result);
227 }
228 
WritePadding(io::OutputStream * stream,int64_t nbytes)229 Status WritePadding(io::OutputStream* stream, int64_t nbytes) {
230   while (nbytes > 0) {
231     const int64_t bytes_to_write = std::min<int64_t>(nbytes, kArrowAlignment);
232     RETURN_NOT_OK(stream->Write(kPaddingBytes, bytes_to_write));
233     nbytes -= bytes_to_write;
234   }
235   return Status::OK();
236 }
237 
SerializeTo(io::OutputStream * stream,const IpcWriteOptions & options,int64_t * output_length) const238 Status Message::SerializeTo(io::OutputStream* stream, const IpcWriteOptions& options,
239                             int64_t* output_length) const {
240   int32_t metadata_length = 0;
241   RETURN_NOT_OK(WriteMessage(*metadata(), options, stream, &metadata_length));
242 
243   *output_length = metadata_length;
244 
245   auto body_buffer = body();
246   if (body_buffer) {
247     RETURN_NOT_OK(stream->Write(body_buffer));
248     *output_length += body_buffer->size();
249 
250     DCHECK_GE(this->body_length(), body_buffer->size());
251 
252     int64_t remainder = this->body_length() - body_buffer->size();
253     RETURN_NOT_OK(WritePadding(stream, remainder));
254     *output_length += remainder;
255   }
256   return Status::OK();
257 }
258 
Verify() const259 bool Message::Verify() const {
260   const flatbuf::Message* unused;
261   return internal::VerifyMessage(metadata()->data(), metadata()->size(), &unused).ok();
262 }
263 
FormatMessageType(MessageType type)264 std::string FormatMessageType(MessageType type) {
265   switch (type) {
266     case MessageType::SCHEMA:
267       return "schema";
268     case MessageType::RECORD_BATCH:
269       return "record batch";
270     case MessageType::DICTIONARY_BATCH:
271       return "dictionary";
272     case MessageType::TENSOR:
273       return "tensor";
274     case MessageType::SPARSE_TENSOR:
275       return "sparse tensor";
276     default:
277       break;
278   }
279   return "unknown";
280 }
281 
ReadMessage(int64_t offset,int32_t metadata_length,io::RandomAccessFile * file)282 Result<std::unique_ptr<Message>> ReadMessage(int64_t offset, int32_t metadata_length,
283                                              io::RandomAccessFile* file) {
284   std::unique_ptr<Message> result;
285   auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
286   MessageDecoder decoder(listener);
287 
288   if (metadata_length < decoder.next_required_size()) {
289     return Status::Invalid("metadata_length should be at least ",
290                            decoder.next_required_size());
291   }
292 
293   ARROW_ASSIGN_OR_RAISE(auto metadata, file->ReadAt(offset, metadata_length));
294   if (metadata->size() < metadata_length) {
295     return Status::Invalid("Expected to read ", metadata_length,
296                            " metadata bytes but got ", metadata->size());
297   }
298   ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
299 
300   switch (decoder.state()) {
301     case MessageDecoder::State::INITIAL:
302       return std::move(result);
303     case MessageDecoder::State::METADATA_LENGTH:
304       return Status::Invalid("metadata length is missing. File offset: ", offset,
305                              ", metadata length: ", metadata_length);
306     case MessageDecoder::State::METADATA:
307       return Status::Invalid("flatbuffer size ", decoder.next_required_size(),
308                              " invalid. File offset: ", offset,
309                              ", metadata length: ", metadata_length);
310     case MessageDecoder::State::BODY: {
311       ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset + metadata_length,
312                                                     decoder.next_required_size()));
313       if (body->size() < decoder.next_required_size()) {
314         return Status::IOError("Expected to be able to read ",
315                                decoder.next_required_size(),
316                                " bytes for message body, got ", body->size());
317       }
318       RETURN_NOT_OK(decoder.Consume(body));
319       return std::move(result);
320     }
321     case MessageDecoder::State::EOS:
322       return Status::Invalid("Unexpected empty message in IPC file format");
323     default:
324       return Status::Invalid("Unexpected state: ", decoder.state());
325   }
326 }
327 
ReadMessageAsync(int64_t offset,int32_t metadata_length,int64_t body_length,io::RandomAccessFile * file,const io::IOContext & context)328 Future<std::shared_ptr<Message>> ReadMessageAsync(int64_t offset, int32_t metadata_length,
329                                                   int64_t body_length,
330                                                   io::RandomAccessFile* file,
331                                                   const io::IOContext& context) {
332   struct State {
333     std::unique_ptr<Message> result;
334     std::shared_ptr<MessageDecoderListener> listener;
335     std::shared_ptr<MessageDecoder> decoder;
336   };
337   auto state = std::make_shared<State>();
338   state->listener = std::make_shared<AssignMessageDecoderListener>(&state->result);
339   state->decoder = std::make_shared<MessageDecoder>(state->listener);
340 
341   if (metadata_length < state->decoder->next_required_size()) {
342     return Status::Invalid("metadata_length should be at least ",
343                            state->decoder->next_required_size());
344   }
345   return file->ReadAsync(context, offset, metadata_length + body_length)
346       .Then([=](std::shared_ptr<Buffer> metadata) -> Result<std::shared_ptr<Message>> {
347         if (metadata->size() < metadata_length) {
348           return Status::Invalid("Expected to read ", metadata_length,
349                                  " metadata bytes but got ", metadata->size());
350         }
351         ARROW_RETURN_NOT_OK(
352             state->decoder->Consume(SliceBuffer(metadata, 0, metadata_length)));
353         switch (state->decoder->state()) {
354           case MessageDecoder::State::INITIAL:
355             return std::move(state->result);
356           case MessageDecoder::State::METADATA_LENGTH:
357             return Status::Invalid("metadata length is missing. File offset: ", offset,
358                                    ", metadata length: ", metadata_length);
359           case MessageDecoder::State::METADATA:
360             return Status::Invalid("flatbuffer size ",
361                                    state->decoder->next_required_size(),
362                                    " invalid. File offset: ", offset,
363                                    ", metadata length: ", metadata_length);
364           case MessageDecoder::State::BODY: {
365             auto body = SliceBuffer(metadata, metadata_length, body_length);
366             if (body->size() < state->decoder->next_required_size()) {
367               return Status::IOError("Expected to be able to read ",
368                                      state->decoder->next_required_size(),
369                                      " bytes for message body, got ", body->size());
370             }
371             RETURN_NOT_OK(state->decoder->Consume(body));
372             return std::move(state->result);
373           }
374           case MessageDecoder::State::EOS:
375             return Status::Invalid("Unexpected empty message in IPC file format");
376           default:
377             return Status::Invalid("Unexpected state: ", state->decoder->state());
378         }
379       });
380 }
381 
AlignStream(io::InputStream * stream,int32_t alignment)382 Status AlignStream(io::InputStream* stream, int32_t alignment) {
383   ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
384   return stream->Advance(PaddedLength(position, alignment) - position);
385 }
386 
AlignStream(io::OutputStream * stream,int32_t alignment)387 Status AlignStream(io::OutputStream* stream, int32_t alignment) {
388   ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
389   int64_t remainder = PaddedLength(position, alignment) - position;
390   if (remainder > 0) {
391     return stream->Write(kPaddingBytes, remainder);
392   }
393   return Status::OK();
394 }
395 
CheckAligned(io::FileInterface * stream,int32_t alignment)396 Status CheckAligned(io::FileInterface* stream, int32_t alignment) {
397   ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
398   if (position % alignment != 0) {
399     return Status::Invalid("Stream is not aligned pos: ", position,
400                            " alignment: ", alignment);
401   } else {
402     return Status::OK();
403   }
404 }
405 
DecodeMessage(MessageDecoder * decoder,io::InputStream * file)406 Status DecodeMessage(MessageDecoder* decoder, io::InputStream* file) {
407   if (decoder->state() == MessageDecoder::State::INITIAL) {
408     uint8_t continuation[sizeof(int32_t)];
409     ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, file->Read(sizeof(int32_t), &continuation));
410     if (bytes_read == 0) {
411       // EOS without indication
412       return Status::OK();
413     } else if (bytes_read != decoder->next_required_size()) {
414       return Status::Invalid("Corrupted message, only ", bytes_read, " bytes available");
415     }
416     ARROW_RETURN_NOT_OK(decoder->Consume(continuation, bytes_read));
417   }
418 
419   if (decoder->state() == MessageDecoder::State::METADATA_LENGTH) {
420     // Valid IPC message, read the message length now
421     uint8_t metadata_length[sizeof(int32_t)];
422     ARROW_ASSIGN_OR_RAISE(int64_t bytes_read,
423                           file->Read(sizeof(int32_t), &metadata_length));
424     if (bytes_read != decoder->next_required_size()) {
425       return Status::Invalid("Corrupted metadata length, only ", bytes_read,
426                              " bytes available");
427     }
428     ARROW_RETURN_NOT_OK(decoder->Consume(metadata_length, bytes_read));
429   }
430 
431   if (decoder->state() == MessageDecoder::State::EOS) {
432     return Status::OK();
433   }
434 
435   auto metadata_length = decoder->next_required_size();
436   ARROW_ASSIGN_OR_RAISE(auto metadata, file->Read(metadata_length));
437   if (metadata->size() != metadata_length) {
438     return Status::Invalid("Expected to read ", metadata_length, " metadata bytes, but ",
439                            "only read ", metadata->size());
440   }
441   ARROW_RETURN_NOT_OK(decoder->Consume(metadata));
442 
443   if (decoder->state() == MessageDecoder::State::BODY) {
444     ARROW_ASSIGN_OR_RAISE(auto body, file->Read(decoder->next_required_size()));
445     if (body->size() < decoder->next_required_size()) {
446       return Status::IOError("Expected to be able to read ",
447                              decoder->next_required_size(),
448                              " bytes for message body, got ", body->size());
449     }
450     ARROW_RETURN_NOT_OK(decoder->Consume(body));
451   }
452 
453   if (decoder->state() == MessageDecoder::State::INITIAL ||
454       decoder->state() == MessageDecoder::State::EOS) {
455     return Status::OK();
456   } else {
457     return Status::Invalid("Failed to decode message");
458   }
459 }
460 
ReadMessage(io::InputStream * file,MemoryPool * pool)461 Result<std::unique_ptr<Message>> ReadMessage(io::InputStream* file, MemoryPool* pool) {
462   std::unique_ptr<Message> message;
463   auto listener = std::make_shared<AssignMessageDecoderListener>(&message);
464   MessageDecoder decoder(listener, pool);
465   ARROW_RETURN_NOT_OK(DecodeMessage(&decoder, file));
466   if (!message) {
467     return nullptr;
468   } else {
469     return std::move(message);
470   }
471 }
472 
WriteMessage(const Buffer & message,const IpcWriteOptions & options,io::OutputStream * file,int32_t * message_length)473 Status WriteMessage(const Buffer& message, const IpcWriteOptions& options,
474                     io::OutputStream* file, int32_t* message_length) {
475   const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8;
476   const int32_t flatbuffer_size = static_cast<int32_t>(message.size());
477 
478   int32_t padded_message_length = static_cast<int32_t>(
479       PaddedLength(flatbuffer_size + prefix_size, options.alignment));
480 
481   int32_t padding = padded_message_length - flatbuffer_size - prefix_size;
482 
483   // The returned message size includes the length prefix, the flatbuffer,
484   // plus padding
485   *message_length = padded_message_length;
486 
487   // ARROW-6314: Write continuation / padding token
488   if (!options.write_legacy_ipc_format) {
489     RETURN_NOT_OK(file->Write(&internal::kIpcContinuationToken, sizeof(int32_t)));
490   }
491 
492   // Write the flatbuffer size prefix including padding in little endian
493   int32_t padded_flatbuffer_size =
494       BitUtil::ToLittleEndian(padded_message_length - prefix_size);
495   RETURN_NOT_OK(file->Write(&padded_flatbuffer_size, sizeof(int32_t)));
496 
497   // Write the flatbuffer
498   RETURN_NOT_OK(file->Write(message.data(), flatbuffer_size));
499   if (padding > 0) {
500     RETURN_NOT_OK(file->Write(kPaddingBytes, padding));
501   }
502 
503   return Status::OK();
504 }
505 
506 // ----------------------------------------------------------------------
507 // Implement MessageDecoder
508 
OnInitial()509 Status MessageDecoderListener::OnInitial() { return Status::OK(); }
OnMetadataLength()510 Status MessageDecoderListener::OnMetadataLength() { return Status::OK(); }
OnMetadata()511 Status MessageDecoderListener::OnMetadata() { return Status::OK(); }
OnBody()512 Status MessageDecoderListener::OnBody() { return Status::OK(); }
OnEOS()513 Status MessageDecoderListener::OnEOS() { return Status::OK(); }
514 
// Bytes needed in the INITIAL state: one 32-bit word holding either the
// continuation token or, in the legacy format, the metadata length.
static constexpr auto kMessageDecoderNextRequiredSizeInitial = sizeof(int32_t);
// Bytes needed in the METADATA_LENGTH state: the little-endian 32-bit metadata
// length that follows the continuation token.
static constexpr auto kMessageDecoderNextRequiredSizeMetadataLength = sizeof(int32_t);
517 
518 class MessageDecoder::MessageDecoderImpl {
519  public:
MessageDecoderImpl(std::shared_ptr<MessageDecoderListener> listener,State initial_state,int64_t initial_next_required_size,MemoryPool * pool)520   explicit MessageDecoderImpl(std::shared_ptr<MessageDecoderListener> listener,
521                               State initial_state, int64_t initial_next_required_size,
522                               MemoryPool* pool)
523       : listener_(std::move(listener)),
524         pool_(pool),
525         state_(initial_state),
526         next_required_size_(initial_next_required_size),
527         chunks_(),
528         buffered_size_(0),
529         metadata_(nullptr) {}
530 
ConsumeData(const uint8_t * data,int64_t size)531   Status ConsumeData(const uint8_t* data, int64_t size) {
532     if (buffered_size_ == 0) {
533       while (size > 0 && size >= next_required_size_) {
534         auto used_size = next_required_size_;
535         switch (state_) {
536           case State::INITIAL:
537             RETURN_NOT_OK(ConsumeInitialData(data, next_required_size_));
538             break;
539           case State::METADATA_LENGTH:
540             RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_));
541             break;
542           case State::METADATA: {
543             auto buffer = std::make_shared<Buffer>(data, next_required_size_);
544             RETURN_NOT_OK(ConsumeMetadataBuffer(buffer));
545           } break;
546           case State::BODY: {
547             auto buffer = std::make_shared<Buffer>(data, next_required_size_);
548             RETURN_NOT_OK(ConsumeBodyBuffer(buffer));
549           } break;
550           case State::EOS:
551             return Status::OK();
552         }
553         data += used_size;
554         size -= used_size;
555       }
556     }
557 
558     if (size == 0) {
559       return Status::OK();
560     }
561 
562     chunks_.push_back(std::make_shared<Buffer>(data, size));
563     buffered_size_ += size;
564     return ConsumeChunks();
565   }
566 
ConsumeBuffer(std::shared_ptr<Buffer> buffer)567   Status ConsumeBuffer(std::shared_ptr<Buffer> buffer) {
568     if (buffered_size_ == 0) {
569       while (buffer->size() >= next_required_size_) {
570         auto used_size = next_required_size_;
571         switch (state_) {
572           case State::INITIAL:
573             RETURN_NOT_OK(ConsumeInitialBuffer(buffer));
574             break;
575           case State::METADATA_LENGTH:
576             RETURN_NOT_OK(ConsumeMetadataLengthBuffer(buffer));
577             break;
578           case State::METADATA:
579             if (buffer->size() == next_required_size_) {
580               return ConsumeMetadataBuffer(buffer);
581             } else {
582               auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
583               RETURN_NOT_OK(ConsumeMetadataBuffer(sliced_buffer));
584             }
585             break;
586           case State::BODY:
587             if (buffer->size() == next_required_size_) {
588               return ConsumeBodyBuffer(buffer);
589             } else {
590               auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
591               RETURN_NOT_OK(ConsumeBodyBuffer(sliced_buffer));
592             }
593             break;
594           case State::EOS:
595             return Status::OK();
596         }
597         if (buffer->size() == used_size) {
598           return Status::OK();
599         }
600         buffer = SliceBuffer(buffer, used_size);
601       }
602     }
603 
604     if (buffer->size() == 0) {
605       return Status::OK();
606     }
607 
608     buffered_size_ += buffer->size();
609     chunks_.push_back(std::move(buffer));
610     return ConsumeChunks();
611   }
612 
next_required_size() const613   int64_t next_required_size() const { return next_required_size_ - buffered_size_; }
614 
state() const615   MessageDecoder::State state() const { return state_; }
616 
617  private:
ConsumeChunks()618   Status ConsumeChunks() {
619     while (state_ != State::EOS) {
620       if (buffered_size_ < next_required_size_) {
621         return Status::OK();
622       }
623 
624       switch (state_) {
625         case State::INITIAL:
626           RETURN_NOT_OK(ConsumeInitialChunks());
627           break;
628         case State::METADATA_LENGTH:
629           RETURN_NOT_OK(ConsumeMetadataLengthChunks());
630           break;
631         case State::METADATA:
632           RETURN_NOT_OK(ConsumeMetadataChunks());
633           break;
634         case State::BODY:
635           RETURN_NOT_OK(ConsumeBodyChunks());
636           break;
637         case State::EOS:
638           return Status::OK();
639       }
640     }
641 
642     return Status::OK();
643   }
644 
ConsumeInitialData(const uint8_t * data,int64_t size)645   Status ConsumeInitialData(const uint8_t* data, int64_t size) {
646     return ConsumeInitial(BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
647   }
648 
ConsumeInitialBuffer(const std::shared_ptr<Buffer> & buffer)649   Status ConsumeInitialBuffer(const std::shared_ptr<Buffer>& buffer) {
650     ARROW_ASSIGN_OR_RAISE(auto continuation, ConsumeDataBufferInt32(buffer));
651     return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
652   }
653 
ConsumeInitialChunks()654   Status ConsumeInitialChunks() {
655     int32_t continuation = 0;
656     RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &continuation));
657     return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
658   }
659 
ConsumeInitial(int32_t continuation)660   Status ConsumeInitial(int32_t continuation) {
661     if (continuation == internal::kIpcContinuationToken) {
662       state_ = State::METADATA_LENGTH;
663       next_required_size_ = kMessageDecoderNextRequiredSizeMetadataLength;
664       RETURN_NOT_OK(listener_->OnMetadataLength());
665       // Valid IPC message, read the message length now
666       return Status::OK();
667     } else if (continuation == 0) {
668       state_ = State::EOS;
669       next_required_size_ = 0;
670       RETURN_NOT_OK(listener_->OnEOS());
671       return Status::OK();
672     } else if (continuation > 0) {
673       state_ = State::METADATA;
674       // ARROW-6314: Backwards compatibility for reading old IPC
675       // messages produced prior to version 0.15.0
676       next_required_size_ = continuation;
677       RETURN_NOT_OK(listener_->OnMetadata());
678       return Status::OK();
679     } else {
680       return Status::IOError("Invalid IPC stream: negative continuation token");
681     }
682   }
683 
ConsumeMetadataLengthData(const uint8_t * data,int64_t size)684   Status ConsumeMetadataLengthData(const uint8_t* data, int64_t size) {
685     return ConsumeMetadataLength(
686         BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
687   }
688 
ConsumeMetadataLengthBuffer(const std::shared_ptr<Buffer> & buffer)689   Status ConsumeMetadataLengthBuffer(const std::shared_ptr<Buffer>& buffer) {
690     ARROW_ASSIGN_OR_RAISE(auto metadata_length, ConsumeDataBufferInt32(buffer));
691     return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
692   }
693 
ConsumeMetadataLengthChunks()694   Status ConsumeMetadataLengthChunks() {
695     int32_t metadata_length = 0;
696     RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &metadata_length));
697     return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
698   }
699 
ConsumeMetadataLength(int32_t metadata_length)700   Status ConsumeMetadataLength(int32_t metadata_length) {
701     if (metadata_length == 0) {
702       state_ = State::EOS;
703       next_required_size_ = 0;
704       RETURN_NOT_OK(listener_->OnEOS());
705       return Status::OK();
706     } else if (metadata_length > 0) {
707       state_ = State::METADATA;
708       next_required_size_ = metadata_length;
709       RETURN_NOT_OK(listener_->OnMetadata());
710       return Status::OK();
711     } else {
712       return Status::IOError("Invalid IPC message: negative metadata length");
713     }
714   }
715 
ConsumeMetadataBuffer(const std::shared_ptr<Buffer> & buffer)716   Status ConsumeMetadataBuffer(const std::shared_ptr<Buffer>& buffer) {
717     if (buffer->is_cpu()) {
718       metadata_ = buffer;
719     } else {
720       ARROW_ASSIGN_OR_RAISE(metadata_,
721                             Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
722     }
723     return ConsumeMetadata();
724   }
725 
ConsumeMetadataChunks()726   Status ConsumeMetadataChunks() {
727     if (chunks_[0]->size() >= next_required_size_) {
728       if (chunks_[0]->size() == next_required_size_) {
729         if (chunks_[0]->is_cpu()) {
730           metadata_ = std::move(chunks_[0]);
731         } else {
732           ARROW_ASSIGN_OR_RAISE(
733               metadata_,
734               Buffer::ViewOrCopy(chunks_[0], CPUDevice::memory_manager(pool_)));
735         }
736         chunks_.erase(chunks_.begin());
737       } else {
738         metadata_ = SliceBuffer(chunks_[0], 0, next_required_size_);
739         if (!chunks_[0]->is_cpu()) {
740           ARROW_ASSIGN_OR_RAISE(
741               metadata_, Buffer::ViewOrCopy(metadata_, CPUDevice::memory_manager(pool_)));
742         }
743         chunks_[0] = SliceBuffer(chunks_[0], next_required_size_);
744       }
745       buffered_size_ -= next_required_size_;
746     } else {
747       ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_));
748       metadata_ = std::shared_ptr<Buffer>(metadata.release());
749       RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data()));
750     }
751     return ConsumeMetadata();
752   }
753 
ConsumeMetadata()754   Status ConsumeMetadata() {
755     RETURN_NOT_OK(MaybeAlignMetadata(&metadata_));
756     int64_t body_length = -1;
757     RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata_, &body_length));
758 
759     state_ = State::BODY;
760     next_required_size_ = body_length;
761     RETURN_NOT_OK(listener_->OnBody());
762     if (next_required_size_ == 0) {
763       ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_));
764       std::shared_ptr<Buffer> shared_body(body.release());
765       return ConsumeBody(&shared_body);
766     } else {
767       return Status::OK();
768     }
769   }
770 
ConsumeBodyBuffer(std::shared_ptr<Buffer> buffer)771   Status ConsumeBodyBuffer(std::shared_ptr<Buffer> buffer) {
772     return ConsumeBody(&buffer);
773   }
774 
ConsumeBodyChunks()775   Status ConsumeBodyChunks() {
776     if (chunks_[0]->size() >= next_required_size_) {
777       auto used_size = next_required_size_;
778       if (chunks_[0]->size() == next_required_size_) {
779         RETURN_NOT_OK(ConsumeBody(&chunks_[0]));
780         chunks_.erase(chunks_.begin());
781       } else {
782         auto body = SliceBuffer(chunks_[0], 0, next_required_size_);
783         RETURN_NOT_OK(ConsumeBody(&body));
784         chunks_[0] = SliceBuffer(chunks_[0], used_size);
785       }
786       buffered_size_ -= used_size;
787       return Status::OK();
788     } else {
789       ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_));
790       RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data()));
791       std::shared_ptr<Buffer> shared_body(body.release());
792       return ConsumeBody(&shared_body);
793     }
794   }
795 
ConsumeBody(std::shared_ptr<Buffer> * buffer)796   Status ConsumeBody(std::shared_ptr<Buffer>* buffer) {
797     ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
798                           Message::Open(metadata_, *buffer));
799 
800     RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message)));
801     state_ = State::INITIAL;
802     next_required_size_ = kMessageDecoderNextRequiredSizeInitial;
803     RETURN_NOT_OK(listener_->OnInitial());
804     return Status::OK();
805   }
806 
ConsumeDataBufferInt32(const std::shared_ptr<Buffer> & buffer)807   Result<int32_t> ConsumeDataBufferInt32(const std::shared_ptr<Buffer>& buffer) {
808     if (buffer->is_cpu()) {
809       return util::SafeLoadAs<int32_t>(buffer->data());
810     } else {
811       ARROW_ASSIGN_OR_RAISE(auto cpu_buffer,
812                             Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
813       return util::SafeLoadAs<int32_t>(cpu_buffer->data());
814     }
815   }
816 
ConsumeDataChunks(int64_t nbytes,void * out)817   Status ConsumeDataChunks(int64_t nbytes, void* out) {
818     size_t offset = 0;
819     size_t n_used_chunks = 0;
820     auto required_size = nbytes;
821     std::shared_ptr<Buffer> last_chunk;
822     for (auto& chunk : chunks_) {
823       if (!chunk->is_cpu()) {
824         ARROW_ASSIGN_OR_RAISE(
825             chunk, Buffer::ViewOrCopy(chunk, CPUDevice::memory_manager(pool_)));
826       }
827       auto data = chunk->data();
828       auto data_size = chunk->size();
829       auto copy_size = std::min(required_size, data_size);
830       memcpy(static_cast<uint8_t*>(out) + offset, data, copy_size);
831       n_used_chunks++;
832       offset += copy_size;
833       required_size -= copy_size;
834       if (required_size == 0) {
835         if (data_size != copy_size) {
836           last_chunk = SliceBuffer(chunk, copy_size);
837         }
838         break;
839       }
840     }
841     chunks_.erase(chunks_.begin(), chunks_.begin() + n_used_chunks);
842     if (last_chunk.get() != nullptr) {
843       chunks_.insert(chunks_.begin(), std::move(last_chunk));
844     }
845     buffered_size_ -= offset;
846     return Status::OK();
847   }
848 
849   std::shared_ptr<MessageDecoderListener> listener_;
850   MemoryPool* pool_;
851   State state_;
852   int64_t next_required_size_;
853   std::vector<std::shared_ptr<Buffer>> chunks_;
854   int64_t buffered_size_;
855   std::shared_ptr<Buffer> metadata_;  // Must be CPU buffer
856 };
857 
MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,MemoryPool * pool)858 MessageDecoder::MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,
859                                MemoryPool* pool) {
860   impl_.reset(new MessageDecoderImpl(std::move(listener), State::INITIAL,
861                                      kMessageDecoderNextRequiredSizeInitial, pool));
862 }
863 
MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,State initial_state,int64_t initial_next_required_size,MemoryPool * pool)864 MessageDecoder::MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,
865                                State initial_state, int64_t initial_next_required_size,
866                                MemoryPool* pool) {
867   impl_.reset(new MessageDecoderImpl(std::move(listener), initial_state,
868                                      initial_next_required_size, pool));
869 }
870 
~MessageDecoder()871 MessageDecoder::~MessageDecoder() {}
872 
Consume(const uint8_t * data,int64_t size)873 Status MessageDecoder::Consume(const uint8_t* data, int64_t size) {
874   return impl_->ConsumeData(data, size);
875 }
876 
Consume(std::shared_ptr<Buffer> buffer)877 Status MessageDecoder::Consume(std::shared_ptr<Buffer> buffer) {
878   return impl_->ConsumeBuffer(buffer);
879 }
880 
next_required_size() const881 int64_t MessageDecoder::next_required_size() const { return impl_->next_required_size(); }
882 
state() const883 MessageDecoder::State MessageDecoder::state() const { return impl_->state(); }
884 
885 // ----------------------------------------------------------------------
886 // Implement InputStream message reader
887 
888 /// \brief Implementation of MessageReader that reads from InputStream
889 class InputStreamMessageReader : public MessageReader, public MessageDecoderListener {
890  public:
InputStreamMessageReader(io::InputStream * stream)891   explicit InputStreamMessageReader(io::InputStream* stream)
892       : stream_(stream),
893         owned_stream_(),
894         message_(),
895         decoder_(std::shared_ptr<InputStreamMessageReader>(this, [](void*) {})) {}
896 
InputStreamMessageReader(const std::shared_ptr<io::InputStream> & owned_stream)897   explicit InputStreamMessageReader(const std::shared_ptr<io::InputStream>& owned_stream)
898       : InputStreamMessageReader(owned_stream.get()) {
899     owned_stream_ = owned_stream;
900   }
901 
~InputStreamMessageReader()902   ~InputStreamMessageReader() {}
903 
OnMessageDecoded(std::unique_ptr<Message> message)904   Status OnMessageDecoded(std::unique_ptr<Message> message) override {
905     message_ = std::move(message);
906     return Status::OK();
907   }
908 
ReadNextMessage()909   Result<std::unique_ptr<Message>> ReadNextMessage() override {
910     ARROW_RETURN_NOT_OK(DecodeMessage(&decoder_, stream_));
911     return std::move(message_);
912   }
913 
914  private:
915   io::InputStream* stream_;
916   std::shared_ptr<io::InputStream> owned_stream_;
917   std::unique_ptr<Message> message_;
918   MessageDecoder decoder_;
919 };
920 
Open(io::InputStream * stream)921 std::unique_ptr<MessageReader> MessageReader::Open(io::InputStream* stream) {
922   return std::unique_ptr<MessageReader>(new InputStreamMessageReader(stream));
923 }
924 
Open(const std::shared_ptr<io::InputStream> & owned_stream)925 std::unique_ptr<MessageReader> MessageReader::Open(
926     const std::shared_ptr<io::InputStream>& owned_stream) {
927   return std::unique_ptr<MessageReader>(new InputStreamMessageReader(owned_stream));
928 }
929 
930 }  // namespace ipc
931 }  // namespace arrow
932