1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/ipc/message.h"
19
20 #include <algorithm>
21 #include <cstddef>
22 #include <cstdint>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "arrow/buffer.h"
29 #include "arrow/device.h"
30 #include "arrow/io/interfaces.h"
31 #include "arrow/ipc/metadata_internal.h"
32 #include "arrow/ipc/options.h"
33 #include "arrow/ipc/util.h"
34 #include "arrow/status.h"
35 #include "arrow/util/endian.h"
36 #include "arrow/util/future.h"
37 #include "arrow/util/logging.h"
38 #include "arrow/util/ubsan.h"
39
40 #include "generated/Message_generated.h"
41
42 namespace arrow {
43
44 class KeyValueMetadata;
45 class MemoryPool;
46
47 namespace ipc {
48
49 class Message::MessageImpl {
50 public:
MessageImpl(std::shared_ptr<Buffer> metadata,std::shared_ptr<Buffer> body)51 explicit MessageImpl(std::shared_ptr<Buffer> metadata, std::shared_ptr<Buffer> body)
52 : metadata_(std::move(metadata)), message_(nullptr), body_(std::move(body)) {}
53
Open()54 Status Open() {
55 RETURN_NOT_OK(
56 internal::VerifyMessage(metadata_->data(), metadata_->size(), &message_));
57
58 // Check that the metadata version is supported
59 if (message_->version() < internal::kMinMetadataVersion) {
60 return Status::Invalid("Old metadata version not supported");
61 }
62
63 if (message_->version() > flatbuf::MetadataVersion::MAX) {
64 return Status::Invalid("Unsupported future MetadataVersion: ",
65 static_cast<int16_t>(message_->version()));
66 }
67
68 if (message_->custom_metadata() != nullptr) {
69 // Deserialize from Flatbuffers if first time called
70 std::shared_ptr<KeyValueMetadata> md;
71 RETURN_NOT_OK(internal::GetKeyValueMetadata(message_->custom_metadata(), &md));
72 custom_metadata_ = std::move(md); // const-ify
73 }
74
75 return Status::OK();
76 }
77
type() const78 MessageType type() const {
79 switch (message_->header_type()) {
80 case flatbuf::MessageHeader::Schema:
81 return MessageType::SCHEMA;
82 case flatbuf::MessageHeader::DictionaryBatch:
83 return MessageType::DICTIONARY_BATCH;
84 case flatbuf::MessageHeader::RecordBatch:
85 return MessageType::RECORD_BATCH;
86 case flatbuf::MessageHeader::Tensor:
87 return MessageType::TENSOR;
88 case flatbuf::MessageHeader::SparseTensor:
89 return MessageType::SPARSE_TENSOR;
90 default:
91 return MessageType::NONE;
92 }
93 }
94
version() const95 MetadataVersion version() const {
96 return internal::GetMetadataVersion(message_->version());
97 }
98
header() const99 const void* header() const { return message_->header(); }
100
body_length() const101 int64_t body_length() const { return message_->bodyLength(); }
102
body() const103 std::shared_ptr<Buffer> body() const { return body_; }
104
metadata() const105 std::shared_ptr<Buffer> metadata() const { return metadata_; }
106
custom_metadata() const107 const std::shared_ptr<const KeyValueMetadata>& custom_metadata() const {
108 return custom_metadata_;
109 }
110
111 private:
112 // The Flatbuffer metadata
113 std::shared_ptr<Buffer> metadata_;
114 const flatbuf::Message* message_;
115
116 // The reconstructed custom_metadata field from the Message Flatbuffer
117 std::shared_ptr<const KeyValueMetadata> custom_metadata_;
118
119 // The message body, if any
120 std::shared_ptr<Buffer> body_;
121 };
122
Message(std::shared_ptr<Buffer> metadata,std::shared_ptr<Buffer> body)123 Message::Message(std::shared_ptr<Buffer> metadata, std::shared_ptr<Buffer> body) {
124 impl_.reset(new MessageImpl(std::move(metadata), std::move(body)));
125 }
126
Open(std::shared_ptr<Buffer> metadata,std::shared_ptr<Buffer> body)127 Result<std::unique_ptr<Message>> Message::Open(std::shared_ptr<Buffer> metadata,
128 std::shared_ptr<Buffer> body) {
129 std::unique_ptr<Message> result(new Message(std::move(metadata), std::move(body)));
130 RETURN_NOT_OK(result->impl_->Open());
131 return std::move(result);
132 }
133
~Message()134 Message::~Message() {}
135
body() const136 std::shared_ptr<Buffer> Message::body() const { return impl_->body(); }
137
body_length() const138 int64_t Message::body_length() const { return impl_->body_length(); }
139
metadata() const140 std::shared_ptr<Buffer> Message::metadata() const { return impl_->metadata(); }
141
type() const142 MessageType Message::type() const { return impl_->type(); }
143
metadata_version() const144 MetadataVersion Message::metadata_version() const { return impl_->version(); }
145
header() const146 const void* Message::header() const { return impl_->header(); }
147
custom_metadata() const148 const std::shared_ptr<const KeyValueMetadata>& Message::custom_metadata() const {
149 return impl_->custom_metadata();
150 }
151
Equals(const Message & other) const152 bool Message::Equals(const Message& other) const {
153 int64_t metadata_bytes = std::min(metadata()->size(), other.metadata()->size());
154
155 if (!metadata()->Equals(*other.metadata(), metadata_bytes)) {
156 return false;
157 }
158
159 // Compare bodies, if they have them
160 auto this_body = body();
161 auto other_body = other.body();
162
163 const bool this_has_body = (this_body != nullptr) && (this_body->size() > 0);
164 const bool other_has_body = (other_body != nullptr) && (other_body->size() > 0);
165
166 if (this_has_body && other_has_body) {
167 return this_body->Equals(*other_body);
168 } else if (this_has_body ^ other_has_body) {
169 // One has a body but not the other
170 return false;
171 } else {
172 // Neither has a body
173 return true;
174 }
175 }
176
MaybeAlignMetadata(std::shared_ptr<Buffer> * metadata)177 Status MaybeAlignMetadata(std::shared_ptr<Buffer>* metadata) {
178 if (reinterpret_cast<uintptr_t>((*metadata)->data()) % 8 != 0) {
179 // If the metadata memory is not aligned, we copy it here to avoid
180 // potential UBSAN issues from Flatbuffers
181 ARROW_ASSIGN_OR_RAISE(*metadata, (*metadata)->CopySlice(0, (*metadata)->size()));
182 }
183 return Status::OK();
184 }
185
CheckMetadataAndGetBodyLength(const Buffer & metadata,int64_t * body_length)186 Status CheckMetadataAndGetBodyLength(const Buffer& metadata, int64_t* body_length) {
187 const flatbuf::Message* fb_message = nullptr;
188 RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &fb_message));
189 *body_length = fb_message->bodyLength();
190 if (*body_length < 0) {
191 return Status::IOError("Invalid IPC message: negative bodyLength");
192 }
193 return Status::OK();
194 }
195
ReadFrom(std::shared_ptr<Buffer> metadata,io::InputStream * stream)196 Result<std::unique_ptr<Message>> Message::ReadFrom(std::shared_ptr<Buffer> metadata,
197 io::InputStream* stream) {
198 std::unique_ptr<Message> result;
199 auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
200 MessageDecoder decoder(listener, MessageDecoder::State::METADATA, metadata->size());
201 ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
202
203 ARROW_ASSIGN_OR_RAISE(auto body, stream->Read(decoder.next_required_size()));
204 if (body->size() < decoder.next_required_size()) {
205 return Status::IOError("Expected to be able to read ", decoder.next_required_size(),
206 " bytes for message body, got ", body->size());
207 }
208 RETURN_NOT_OK(decoder.Consume(body));
209 return std::move(result);
210 }
211
ReadFrom(const int64_t offset,std::shared_ptr<Buffer> metadata,io::RandomAccessFile * file)212 Result<std::unique_ptr<Message>> Message::ReadFrom(const int64_t offset,
213 std::shared_ptr<Buffer> metadata,
214 io::RandomAccessFile* file) {
215 std::unique_ptr<Message> result;
216 auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
217 MessageDecoder decoder(listener, MessageDecoder::State::METADATA, metadata->size());
218 ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
219
220 ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset, decoder.next_required_size()));
221 if (body->size() < decoder.next_required_size()) {
222 return Status::IOError("Expected to be able to read ", decoder.next_required_size(),
223 " bytes for message body, got ", body->size());
224 }
225 RETURN_NOT_OK(decoder.Consume(body));
226 return std::move(result);
227 }
228
WritePadding(io::OutputStream * stream,int64_t nbytes)229 Status WritePadding(io::OutputStream* stream, int64_t nbytes) {
230 while (nbytes > 0) {
231 const int64_t bytes_to_write = std::min<int64_t>(nbytes, kArrowAlignment);
232 RETURN_NOT_OK(stream->Write(kPaddingBytes, bytes_to_write));
233 nbytes -= bytes_to_write;
234 }
235 return Status::OK();
236 }
237
SerializeTo(io::OutputStream * stream,const IpcWriteOptions & options,int64_t * output_length) const238 Status Message::SerializeTo(io::OutputStream* stream, const IpcWriteOptions& options,
239 int64_t* output_length) const {
240 int32_t metadata_length = 0;
241 RETURN_NOT_OK(WriteMessage(*metadata(), options, stream, &metadata_length));
242
243 *output_length = metadata_length;
244
245 auto body_buffer = body();
246 if (body_buffer) {
247 RETURN_NOT_OK(stream->Write(body_buffer));
248 *output_length += body_buffer->size();
249
250 DCHECK_GE(this->body_length(), body_buffer->size());
251
252 int64_t remainder = this->body_length() - body_buffer->size();
253 RETURN_NOT_OK(WritePadding(stream, remainder));
254 *output_length += remainder;
255 }
256 return Status::OK();
257 }
258
Verify() const259 bool Message::Verify() const {
260 const flatbuf::Message* unused;
261 return internal::VerifyMessage(metadata()->data(), metadata()->size(), &unused).ok();
262 }
263
FormatMessageType(MessageType type)264 std::string FormatMessageType(MessageType type) {
265 switch (type) {
266 case MessageType::SCHEMA:
267 return "schema";
268 case MessageType::RECORD_BATCH:
269 return "record batch";
270 case MessageType::DICTIONARY_BATCH:
271 return "dictionary";
272 case MessageType::TENSOR:
273 return "tensor";
274 case MessageType::SPARSE_TENSOR:
275 return "sparse tensor";
276 default:
277 break;
278 }
279 return "unknown";
280 }
281
ReadMessage(int64_t offset,int32_t metadata_length,io::RandomAccessFile * file)282 Result<std::unique_ptr<Message>> ReadMessage(int64_t offset, int32_t metadata_length,
283 io::RandomAccessFile* file) {
284 std::unique_ptr<Message> result;
285 auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
286 MessageDecoder decoder(listener);
287
288 if (metadata_length < decoder.next_required_size()) {
289 return Status::Invalid("metadata_length should be at least ",
290 decoder.next_required_size());
291 }
292
293 ARROW_ASSIGN_OR_RAISE(auto metadata, file->ReadAt(offset, metadata_length));
294 if (metadata->size() < metadata_length) {
295 return Status::Invalid("Expected to read ", metadata_length,
296 " metadata bytes but got ", metadata->size());
297 }
298 ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
299
300 switch (decoder.state()) {
301 case MessageDecoder::State::INITIAL:
302 return std::move(result);
303 case MessageDecoder::State::METADATA_LENGTH:
304 return Status::Invalid("metadata length is missing. File offset: ", offset,
305 ", metadata length: ", metadata_length);
306 case MessageDecoder::State::METADATA:
307 return Status::Invalid("flatbuffer size ", decoder.next_required_size(),
308 " invalid. File offset: ", offset,
309 ", metadata length: ", metadata_length);
310 case MessageDecoder::State::BODY: {
311 ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset + metadata_length,
312 decoder.next_required_size()));
313 if (body->size() < decoder.next_required_size()) {
314 return Status::IOError("Expected to be able to read ",
315 decoder.next_required_size(),
316 " bytes for message body, got ", body->size());
317 }
318 RETURN_NOT_OK(decoder.Consume(body));
319 return std::move(result);
320 }
321 case MessageDecoder::State::EOS:
322 return Status::Invalid("Unexpected empty message in IPC file format");
323 default:
324 return Status::Invalid("Unexpected state: ", decoder.state());
325 }
326 }
327
ReadMessageAsync(int64_t offset,int32_t metadata_length,int64_t body_length,io::RandomAccessFile * file,const io::IOContext & context)328 Future<std::shared_ptr<Message>> ReadMessageAsync(int64_t offset, int32_t metadata_length,
329 int64_t body_length,
330 io::RandomAccessFile* file,
331 const io::IOContext& context) {
332 struct State {
333 std::unique_ptr<Message> result;
334 std::shared_ptr<MessageDecoderListener> listener;
335 std::shared_ptr<MessageDecoder> decoder;
336 };
337 auto state = std::make_shared<State>();
338 state->listener = std::make_shared<AssignMessageDecoderListener>(&state->result);
339 state->decoder = std::make_shared<MessageDecoder>(state->listener);
340
341 if (metadata_length < state->decoder->next_required_size()) {
342 return Status::Invalid("metadata_length should be at least ",
343 state->decoder->next_required_size());
344 }
345 return file->ReadAsync(context, offset, metadata_length + body_length)
346 .Then([=](std::shared_ptr<Buffer> metadata) -> Result<std::shared_ptr<Message>> {
347 if (metadata->size() < metadata_length) {
348 return Status::Invalid("Expected to read ", metadata_length,
349 " metadata bytes but got ", metadata->size());
350 }
351 ARROW_RETURN_NOT_OK(
352 state->decoder->Consume(SliceBuffer(metadata, 0, metadata_length)));
353 switch (state->decoder->state()) {
354 case MessageDecoder::State::INITIAL:
355 return std::move(state->result);
356 case MessageDecoder::State::METADATA_LENGTH:
357 return Status::Invalid("metadata length is missing. File offset: ", offset,
358 ", metadata length: ", metadata_length);
359 case MessageDecoder::State::METADATA:
360 return Status::Invalid("flatbuffer size ",
361 state->decoder->next_required_size(),
362 " invalid. File offset: ", offset,
363 ", metadata length: ", metadata_length);
364 case MessageDecoder::State::BODY: {
365 auto body = SliceBuffer(metadata, metadata_length, body_length);
366 if (body->size() < state->decoder->next_required_size()) {
367 return Status::IOError("Expected to be able to read ",
368 state->decoder->next_required_size(),
369 " bytes for message body, got ", body->size());
370 }
371 RETURN_NOT_OK(state->decoder->Consume(body));
372 return std::move(state->result);
373 }
374 case MessageDecoder::State::EOS:
375 return Status::Invalid("Unexpected empty message in IPC file format");
376 default:
377 return Status::Invalid("Unexpected state: ", state->decoder->state());
378 }
379 });
380 }
381
AlignStream(io::InputStream * stream,int32_t alignment)382 Status AlignStream(io::InputStream* stream, int32_t alignment) {
383 ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
384 return stream->Advance(PaddedLength(position, alignment) - position);
385 }
386
AlignStream(io::OutputStream * stream,int32_t alignment)387 Status AlignStream(io::OutputStream* stream, int32_t alignment) {
388 ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
389 int64_t remainder = PaddedLength(position, alignment) - position;
390 if (remainder > 0) {
391 return stream->Write(kPaddingBytes, remainder);
392 }
393 return Status::OK();
394 }
395
CheckAligned(io::FileInterface * stream,int32_t alignment)396 Status CheckAligned(io::FileInterface* stream, int32_t alignment) {
397 ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
398 if (position % alignment != 0) {
399 return Status::Invalid("Stream is not aligned pos: ", position,
400 " alignment: ", alignment);
401 } else {
402 return Status::OK();
403 }
404 }
405
DecodeMessage(MessageDecoder * decoder,io::InputStream * file)406 Status DecodeMessage(MessageDecoder* decoder, io::InputStream* file) {
407 if (decoder->state() == MessageDecoder::State::INITIAL) {
408 uint8_t continuation[sizeof(int32_t)];
409 ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, file->Read(sizeof(int32_t), &continuation));
410 if (bytes_read == 0) {
411 // EOS without indication
412 return Status::OK();
413 } else if (bytes_read != decoder->next_required_size()) {
414 return Status::Invalid("Corrupted message, only ", bytes_read, " bytes available");
415 }
416 ARROW_RETURN_NOT_OK(decoder->Consume(continuation, bytes_read));
417 }
418
419 if (decoder->state() == MessageDecoder::State::METADATA_LENGTH) {
420 // Valid IPC message, read the message length now
421 uint8_t metadata_length[sizeof(int32_t)];
422 ARROW_ASSIGN_OR_RAISE(int64_t bytes_read,
423 file->Read(sizeof(int32_t), &metadata_length));
424 if (bytes_read != decoder->next_required_size()) {
425 return Status::Invalid("Corrupted metadata length, only ", bytes_read,
426 " bytes available");
427 }
428 ARROW_RETURN_NOT_OK(decoder->Consume(metadata_length, bytes_read));
429 }
430
431 if (decoder->state() == MessageDecoder::State::EOS) {
432 return Status::OK();
433 }
434
435 auto metadata_length = decoder->next_required_size();
436 ARROW_ASSIGN_OR_RAISE(auto metadata, file->Read(metadata_length));
437 if (metadata->size() != metadata_length) {
438 return Status::Invalid("Expected to read ", metadata_length, " metadata bytes, but ",
439 "only read ", metadata->size());
440 }
441 ARROW_RETURN_NOT_OK(decoder->Consume(metadata));
442
443 if (decoder->state() == MessageDecoder::State::BODY) {
444 ARROW_ASSIGN_OR_RAISE(auto body, file->Read(decoder->next_required_size()));
445 if (body->size() < decoder->next_required_size()) {
446 return Status::IOError("Expected to be able to read ",
447 decoder->next_required_size(),
448 " bytes for message body, got ", body->size());
449 }
450 ARROW_RETURN_NOT_OK(decoder->Consume(body));
451 }
452
453 if (decoder->state() == MessageDecoder::State::INITIAL ||
454 decoder->state() == MessageDecoder::State::EOS) {
455 return Status::OK();
456 } else {
457 return Status::Invalid("Failed to decode message");
458 }
459 }
460
ReadMessage(io::InputStream * file,MemoryPool * pool)461 Result<std::unique_ptr<Message>> ReadMessage(io::InputStream* file, MemoryPool* pool) {
462 std::unique_ptr<Message> message;
463 auto listener = std::make_shared<AssignMessageDecoderListener>(&message);
464 MessageDecoder decoder(listener, pool);
465 ARROW_RETURN_NOT_OK(DecodeMessage(&decoder, file));
466 if (!message) {
467 return nullptr;
468 } else {
469 return std::move(message);
470 }
471 }
472
WriteMessage(const Buffer & message,const IpcWriteOptions & options,io::OutputStream * file,int32_t * message_length)473 Status WriteMessage(const Buffer& message, const IpcWriteOptions& options,
474 io::OutputStream* file, int32_t* message_length) {
475 const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8;
476 const int32_t flatbuffer_size = static_cast<int32_t>(message.size());
477
478 int32_t padded_message_length = static_cast<int32_t>(
479 PaddedLength(flatbuffer_size + prefix_size, options.alignment));
480
481 int32_t padding = padded_message_length - flatbuffer_size - prefix_size;
482
483 // The returned message size includes the length prefix, the flatbuffer,
484 // plus padding
485 *message_length = padded_message_length;
486
487 // ARROW-6314: Write continuation / padding token
488 if (!options.write_legacy_ipc_format) {
489 RETURN_NOT_OK(file->Write(&internal::kIpcContinuationToken, sizeof(int32_t)));
490 }
491
492 // Write the flatbuffer size prefix including padding in little endian
493 int32_t padded_flatbuffer_size =
494 BitUtil::ToLittleEndian(padded_message_length - prefix_size);
495 RETURN_NOT_OK(file->Write(&padded_flatbuffer_size, sizeof(int32_t)));
496
497 // Write the flatbuffer
498 RETURN_NOT_OK(file->Write(message.data(), flatbuffer_size));
499 if (padding > 0) {
500 RETURN_NOT_OK(file->Write(kPaddingBytes, padding));
501 }
502
503 return Status::OK();
504 }
505
506 // ----------------------------------------------------------------------
507 // Implement MessageDecoder
508
OnInitial()509 Status MessageDecoderListener::OnInitial() { return Status::OK(); }
OnMetadataLength()510 Status MessageDecoderListener::OnMetadataLength() { return Status::OK(); }
OnMetadata()511 Status MessageDecoderListener::OnMetadata() { return Status::OK(); }
OnBody()512 Status MessageDecoderListener::OnBody() { return Status::OK(); }
OnEOS()513 Status MessageDecoderListener::OnEOS() { return Status::OK(); }
514
// Bytes the decoder requires in State::INITIAL: one int32
static constexpr std::size_t kMessageDecoderNextRequiredSizeInitial = sizeof(int32_t);
// Bytes the decoder requires in State::METADATA_LENGTH: one int32
static constexpr std::size_t kMessageDecoderNextRequiredSizeMetadataLength =
    sizeof(int32_t);
517
518 class MessageDecoder::MessageDecoderImpl {
519 public:
MessageDecoderImpl(std::shared_ptr<MessageDecoderListener> listener,State initial_state,int64_t initial_next_required_size,MemoryPool * pool)520 explicit MessageDecoderImpl(std::shared_ptr<MessageDecoderListener> listener,
521 State initial_state, int64_t initial_next_required_size,
522 MemoryPool* pool)
523 : listener_(std::move(listener)),
524 pool_(pool),
525 state_(initial_state),
526 next_required_size_(initial_next_required_size),
527 chunks_(),
528 buffered_size_(0),
529 metadata_(nullptr) {}
530
ConsumeData(const uint8_t * data,int64_t size)531 Status ConsumeData(const uint8_t* data, int64_t size) {
532 if (buffered_size_ == 0) {
533 while (size > 0 && size >= next_required_size_) {
534 auto used_size = next_required_size_;
535 switch (state_) {
536 case State::INITIAL:
537 RETURN_NOT_OK(ConsumeInitialData(data, next_required_size_));
538 break;
539 case State::METADATA_LENGTH:
540 RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_));
541 break;
542 case State::METADATA: {
543 auto buffer = std::make_shared<Buffer>(data, next_required_size_);
544 RETURN_NOT_OK(ConsumeMetadataBuffer(buffer));
545 } break;
546 case State::BODY: {
547 auto buffer = std::make_shared<Buffer>(data, next_required_size_);
548 RETURN_NOT_OK(ConsumeBodyBuffer(buffer));
549 } break;
550 case State::EOS:
551 return Status::OK();
552 }
553 data += used_size;
554 size -= used_size;
555 }
556 }
557
558 if (size == 0) {
559 return Status::OK();
560 }
561
562 chunks_.push_back(std::make_shared<Buffer>(data, size));
563 buffered_size_ += size;
564 return ConsumeChunks();
565 }
566
ConsumeBuffer(std::shared_ptr<Buffer> buffer)567 Status ConsumeBuffer(std::shared_ptr<Buffer> buffer) {
568 if (buffered_size_ == 0) {
569 while (buffer->size() >= next_required_size_) {
570 auto used_size = next_required_size_;
571 switch (state_) {
572 case State::INITIAL:
573 RETURN_NOT_OK(ConsumeInitialBuffer(buffer));
574 break;
575 case State::METADATA_LENGTH:
576 RETURN_NOT_OK(ConsumeMetadataLengthBuffer(buffer));
577 break;
578 case State::METADATA:
579 if (buffer->size() == next_required_size_) {
580 return ConsumeMetadataBuffer(buffer);
581 } else {
582 auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
583 RETURN_NOT_OK(ConsumeMetadataBuffer(sliced_buffer));
584 }
585 break;
586 case State::BODY:
587 if (buffer->size() == next_required_size_) {
588 return ConsumeBodyBuffer(buffer);
589 } else {
590 auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
591 RETURN_NOT_OK(ConsumeBodyBuffer(sliced_buffer));
592 }
593 break;
594 case State::EOS:
595 return Status::OK();
596 }
597 if (buffer->size() == used_size) {
598 return Status::OK();
599 }
600 buffer = SliceBuffer(buffer, used_size);
601 }
602 }
603
604 if (buffer->size() == 0) {
605 return Status::OK();
606 }
607
608 buffered_size_ += buffer->size();
609 chunks_.push_back(std::move(buffer));
610 return ConsumeChunks();
611 }
612
next_required_size() const613 int64_t next_required_size() const { return next_required_size_ - buffered_size_; }
614
state() const615 MessageDecoder::State state() const { return state_; }
616
617 private:
ConsumeChunks()618 Status ConsumeChunks() {
619 while (state_ != State::EOS) {
620 if (buffered_size_ < next_required_size_) {
621 return Status::OK();
622 }
623
624 switch (state_) {
625 case State::INITIAL:
626 RETURN_NOT_OK(ConsumeInitialChunks());
627 break;
628 case State::METADATA_LENGTH:
629 RETURN_NOT_OK(ConsumeMetadataLengthChunks());
630 break;
631 case State::METADATA:
632 RETURN_NOT_OK(ConsumeMetadataChunks());
633 break;
634 case State::BODY:
635 RETURN_NOT_OK(ConsumeBodyChunks());
636 break;
637 case State::EOS:
638 return Status::OK();
639 }
640 }
641
642 return Status::OK();
643 }
644
ConsumeInitialData(const uint8_t * data,int64_t size)645 Status ConsumeInitialData(const uint8_t* data, int64_t size) {
646 return ConsumeInitial(BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
647 }
648
ConsumeInitialBuffer(const std::shared_ptr<Buffer> & buffer)649 Status ConsumeInitialBuffer(const std::shared_ptr<Buffer>& buffer) {
650 ARROW_ASSIGN_OR_RAISE(auto continuation, ConsumeDataBufferInt32(buffer));
651 return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
652 }
653
ConsumeInitialChunks()654 Status ConsumeInitialChunks() {
655 int32_t continuation = 0;
656 RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &continuation));
657 return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
658 }
659
ConsumeInitial(int32_t continuation)660 Status ConsumeInitial(int32_t continuation) {
661 if (continuation == internal::kIpcContinuationToken) {
662 state_ = State::METADATA_LENGTH;
663 next_required_size_ = kMessageDecoderNextRequiredSizeMetadataLength;
664 RETURN_NOT_OK(listener_->OnMetadataLength());
665 // Valid IPC message, read the message length now
666 return Status::OK();
667 } else if (continuation == 0) {
668 state_ = State::EOS;
669 next_required_size_ = 0;
670 RETURN_NOT_OK(listener_->OnEOS());
671 return Status::OK();
672 } else if (continuation > 0) {
673 state_ = State::METADATA;
674 // ARROW-6314: Backwards compatibility for reading old IPC
675 // messages produced prior to version 0.15.0
676 next_required_size_ = continuation;
677 RETURN_NOT_OK(listener_->OnMetadata());
678 return Status::OK();
679 } else {
680 return Status::IOError("Invalid IPC stream: negative continuation token");
681 }
682 }
683
ConsumeMetadataLengthData(const uint8_t * data,int64_t size)684 Status ConsumeMetadataLengthData(const uint8_t* data, int64_t size) {
685 return ConsumeMetadataLength(
686 BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
687 }
688
ConsumeMetadataLengthBuffer(const std::shared_ptr<Buffer> & buffer)689 Status ConsumeMetadataLengthBuffer(const std::shared_ptr<Buffer>& buffer) {
690 ARROW_ASSIGN_OR_RAISE(auto metadata_length, ConsumeDataBufferInt32(buffer));
691 return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
692 }
693
ConsumeMetadataLengthChunks()694 Status ConsumeMetadataLengthChunks() {
695 int32_t metadata_length = 0;
696 RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &metadata_length));
697 return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
698 }
699
ConsumeMetadataLength(int32_t metadata_length)700 Status ConsumeMetadataLength(int32_t metadata_length) {
701 if (metadata_length == 0) {
702 state_ = State::EOS;
703 next_required_size_ = 0;
704 RETURN_NOT_OK(listener_->OnEOS());
705 return Status::OK();
706 } else if (metadata_length > 0) {
707 state_ = State::METADATA;
708 next_required_size_ = metadata_length;
709 RETURN_NOT_OK(listener_->OnMetadata());
710 return Status::OK();
711 } else {
712 return Status::IOError("Invalid IPC message: negative metadata length");
713 }
714 }
715
ConsumeMetadataBuffer(const std::shared_ptr<Buffer> & buffer)716 Status ConsumeMetadataBuffer(const std::shared_ptr<Buffer>& buffer) {
717 if (buffer->is_cpu()) {
718 metadata_ = buffer;
719 } else {
720 ARROW_ASSIGN_OR_RAISE(metadata_,
721 Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
722 }
723 return ConsumeMetadata();
724 }
725
ConsumeMetadataChunks()726 Status ConsumeMetadataChunks() {
727 if (chunks_[0]->size() >= next_required_size_) {
728 if (chunks_[0]->size() == next_required_size_) {
729 if (chunks_[0]->is_cpu()) {
730 metadata_ = std::move(chunks_[0]);
731 } else {
732 ARROW_ASSIGN_OR_RAISE(
733 metadata_,
734 Buffer::ViewOrCopy(chunks_[0], CPUDevice::memory_manager(pool_)));
735 }
736 chunks_.erase(chunks_.begin());
737 } else {
738 metadata_ = SliceBuffer(chunks_[0], 0, next_required_size_);
739 if (!chunks_[0]->is_cpu()) {
740 ARROW_ASSIGN_OR_RAISE(
741 metadata_, Buffer::ViewOrCopy(metadata_, CPUDevice::memory_manager(pool_)));
742 }
743 chunks_[0] = SliceBuffer(chunks_[0], next_required_size_);
744 }
745 buffered_size_ -= next_required_size_;
746 } else {
747 ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_));
748 metadata_ = std::shared_ptr<Buffer>(metadata.release());
749 RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data()));
750 }
751 return ConsumeMetadata();
752 }
753
ConsumeMetadata()754 Status ConsumeMetadata() {
755 RETURN_NOT_OK(MaybeAlignMetadata(&metadata_));
756 int64_t body_length = -1;
757 RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata_, &body_length));
758
759 state_ = State::BODY;
760 next_required_size_ = body_length;
761 RETURN_NOT_OK(listener_->OnBody());
762 if (next_required_size_ == 0) {
763 ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_));
764 std::shared_ptr<Buffer> shared_body(body.release());
765 return ConsumeBody(&shared_body);
766 } else {
767 return Status::OK();
768 }
769 }
770
ConsumeBodyBuffer(std::shared_ptr<Buffer> buffer)771 Status ConsumeBodyBuffer(std::shared_ptr<Buffer> buffer) {
772 return ConsumeBody(&buffer);
773 }
774
ConsumeBodyChunks()775 Status ConsumeBodyChunks() {
776 if (chunks_[0]->size() >= next_required_size_) {
777 auto used_size = next_required_size_;
778 if (chunks_[0]->size() == next_required_size_) {
779 RETURN_NOT_OK(ConsumeBody(&chunks_[0]));
780 chunks_.erase(chunks_.begin());
781 } else {
782 auto body = SliceBuffer(chunks_[0], 0, next_required_size_);
783 RETURN_NOT_OK(ConsumeBody(&body));
784 chunks_[0] = SliceBuffer(chunks_[0], used_size);
785 }
786 buffered_size_ -= used_size;
787 return Status::OK();
788 } else {
789 ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_));
790 RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data()));
791 std::shared_ptr<Buffer> shared_body(body.release());
792 return ConsumeBody(&shared_body);
793 }
794 }
795
ConsumeBody(std::shared_ptr<Buffer> * buffer)796 Status ConsumeBody(std::shared_ptr<Buffer>* buffer) {
797 ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
798 Message::Open(metadata_, *buffer));
799
800 RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message)));
801 state_ = State::INITIAL;
802 next_required_size_ = kMessageDecoderNextRequiredSizeInitial;
803 RETURN_NOT_OK(listener_->OnInitial());
804 return Status::OK();
805 }
806
ConsumeDataBufferInt32(const std::shared_ptr<Buffer> & buffer)807 Result<int32_t> ConsumeDataBufferInt32(const std::shared_ptr<Buffer>& buffer) {
808 if (buffer->is_cpu()) {
809 return util::SafeLoadAs<int32_t>(buffer->data());
810 } else {
811 ARROW_ASSIGN_OR_RAISE(auto cpu_buffer,
812 Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
813 return util::SafeLoadAs<int32_t>(cpu_buffer->data());
814 }
815 }
816
ConsumeDataChunks(int64_t nbytes,void * out)817 Status ConsumeDataChunks(int64_t nbytes, void* out) {
818 size_t offset = 0;
819 size_t n_used_chunks = 0;
820 auto required_size = nbytes;
821 std::shared_ptr<Buffer> last_chunk;
822 for (auto& chunk : chunks_) {
823 if (!chunk->is_cpu()) {
824 ARROW_ASSIGN_OR_RAISE(
825 chunk, Buffer::ViewOrCopy(chunk, CPUDevice::memory_manager(pool_)));
826 }
827 auto data = chunk->data();
828 auto data_size = chunk->size();
829 auto copy_size = std::min(required_size, data_size);
830 memcpy(static_cast<uint8_t*>(out) + offset, data, copy_size);
831 n_used_chunks++;
832 offset += copy_size;
833 required_size -= copy_size;
834 if (required_size == 0) {
835 if (data_size != copy_size) {
836 last_chunk = SliceBuffer(chunk, copy_size);
837 }
838 break;
839 }
840 }
841 chunks_.erase(chunks_.begin(), chunks_.begin() + n_used_chunks);
842 if (last_chunk.get() != nullptr) {
843 chunks_.insert(chunks_.begin(), std::move(last_chunk));
844 }
845 buffered_size_ -= offset;
846 return Status::OK();
847 }
848
// Receives OnMessageDecoded() for each decoded message and OnInitial() once
// the decoder has been rearmed.
std::shared_ptr<MessageDecoderListener> listener_;
// Pool used to allocate assembled body buffers and to obtain the CPU memory
// manager when viewing or copying non-CPU buffers.
MemoryPool* pool_;
// Current decoding state; set back to State::INITIAL after a message is
// delivered.
State state_;
// Number of bytes needed before the next decoding step can run.
int64_t next_required_size_;
// Input chunks buffered so far; they are consumed from the front.
std::vector<std::shared_ptr<Buffer>> chunks_;
// Byte count of buffered input; reduced by the number of bytes consumed.
// NOTE(review): presumably the total size of chunks_ -- confirm.
int64_t buffered_size_;
// Metadata buffer passed to Message::Open() together with the body.
std::shared_ptr<Buffer> metadata_; // Must be CPU buffer
856 };
857
MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,MemoryPool * pool)858 MessageDecoder::MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,
859 MemoryPool* pool) {
860 impl_.reset(new MessageDecoderImpl(std::move(listener), State::INITIAL,
861 kMessageDecoderNextRequiredSizeInitial, pool));
862 }
863
MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,State initial_state,int64_t initial_next_required_size,MemoryPool * pool)864 MessageDecoder::MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,
865 State initial_state, int64_t initial_next_required_size,
866 MemoryPool* pool) {
867 impl_.reset(new MessageDecoderImpl(std::move(listener), initial_state,
868 initial_next_required_size, pool));
869 }
870
~MessageDecoder()871 MessageDecoder::~MessageDecoder() {}
872
Consume(const uint8_t * data,int64_t size)873 Status MessageDecoder::Consume(const uint8_t* data, int64_t size) {
874 return impl_->ConsumeData(data, size);
875 }
876
Consume(std::shared_ptr<Buffer> buffer)877 Status MessageDecoder::Consume(std::shared_ptr<Buffer> buffer) {
878 return impl_->ConsumeBuffer(buffer);
879 }
880
next_required_size() const881 int64_t MessageDecoder::next_required_size() const { return impl_->next_required_size(); }
882
state() const883 MessageDecoder::State MessageDecoder::state() const { return impl_->state(); }
884
885 // ----------------------------------------------------------------------
886 // Implement InputStream message reader
887
888 /// \brief Implementation of MessageReader that reads from InputStream
889 class InputStreamMessageReader : public MessageReader, public MessageDecoderListener {
890 public:
InputStreamMessageReader(io::InputStream * stream)891 explicit InputStreamMessageReader(io::InputStream* stream)
892 : stream_(stream),
893 owned_stream_(),
894 message_(),
895 decoder_(std::shared_ptr<InputStreamMessageReader>(this, [](void*) {})) {}
896
InputStreamMessageReader(const std::shared_ptr<io::InputStream> & owned_stream)897 explicit InputStreamMessageReader(const std::shared_ptr<io::InputStream>& owned_stream)
898 : InputStreamMessageReader(owned_stream.get()) {
899 owned_stream_ = owned_stream;
900 }
901
~InputStreamMessageReader()902 ~InputStreamMessageReader() {}
903
OnMessageDecoded(std::unique_ptr<Message> message)904 Status OnMessageDecoded(std::unique_ptr<Message> message) override {
905 message_ = std::move(message);
906 return Status::OK();
907 }
908
ReadNextMessage()909 Result<std::unique_ptr<Message>> ReadNextMessage() override {
910 ARROW_RETURN_NOT_OK(DecodeMessage(&decoder_, stream_));
911 return std::move(message_);
912 }
913
914 private:
915 io::InputStream* stream_;
916 std::shared_ptr<io::InputStream> owned_stream_;
917 std::unique_ptr<Message> message_;
918 MessageDecoder decoder_;
919 };
920
Open(io::InputStream * stream)921 std::unique_ptr<MessageReader> MessageReader::Open(io::InputStream* stream) {
922 return std::unique_ptr<MessageReader>(new InputStreamMessageReader(stream));
923 }
924
Open(const std::shared_ptr<io::InputStream> & owned_stream)925 std::unique_ptr<MessageReader> MessageReader::Open(
926 const std::shared_ptr<io::InputStream>& owned_stream) {
927 return std::unique_ptr<MessageReader>(new InputStreamMessageReader(owned_stream));
928 }
929
930 } // namespace ipc
931 } // namespace arrow
932