1 /*! 2 * Copyright (c) 2015 by Contributors 3 * \file threaded_input_split.h 4 * \brief a threaded version of InputSplit with a prefetch thread 5 * \author Tianqi Chen 6 */ 7 #ifndef DMLC_IO_THREADED_INPUT_SPLIT_H_ 8 #define DMLC_IO_THREADED_INPUT_SPLIT_H_ 9 10 #include <dmlc/base.h> 11 // this code depends on c++11 12 #if DMLC_ENABLE_STD_THREAD 13 #include <dmlc/threadediter.h> 14 #include <algorithm> 15 #include "./input_split_base.h" 16 17 namespace dmlc { 18 namespace io { 19 /*! 20 * \brief a threaded version of InputSplit 21 * wraps an InputSplitBase to use an thread to prefetch the data 22 */ 23 class ThreadedInputSplit : public InputSplit { 24 public: 25 /*! 26 * \brief constructor 27 * \param base an base object to define how to read data 28 */ ThreadedInputSplit(InputSplitBase * base,const size_t batch_size)29 explicit ThreadedInputSplit(InputSplitBase *base, const size_t batch_size) 30 : buffer_size_(InputSplitBase::kBufferSize), 31 batch_size_(batch_size), 32 base_(base), tmp_chunk_(NULL) { 33 iter_.set_max_capacity(2); 34 // initalize the iterator 35 iter_.Init([this](InputSplitBase::Chunk **dptr) { 36 if (*dptr == NULL) { 37 *dptr = new InputSplitBase::Chunk(buffer_size_); 38 } 39 return base_->NextBatchEx(*dptr, batch_size_); 40 }, 41 [base]() { base->BeforeFirst(); }); 42 } 43 // destructor ~ThreadedInputSplit(void)44 virtual ~ThreadedInputSplit(void) { 45 iter_.Destroy(); 46 delete tmp_chunk_; 47 delete base_; 48 } BeforeFirst()49 virtual void BeforeFirst() { 50 iter_.BeforeFirst(); 51 if (tmp_chunk_ != NULL) { 52 iter_.Recycle(&tmp_chunk_); 53 } 54 } HintChunkSize(size_t chunk_size)55 virtual void HintChunkSize(size_t chunk_size) { 56 buffer_size_ = std::max(chunk_size / sizeof(uint32_t), buffer_size_); 57 } 58 // implement next record NextRecord(Blob * out_rec)59 virtual bool NextRecord(Blob *out_rec) { 60 if (tmp_chunk_ == NULL) { 61 if (!iter_.Next(&tmp_chunk_)) return false; 62 } 63 while (!base_->ExtractNextRecord(out_rec, tmp_chunk_)) { 64 iter_.Recycle(&tmp_chunk_); 65 if (!iter_.Next(&tmp_chunk_)) return false; 66 } 67 return true; 68 } 69 // implement next chunk NextChunk(Blob * out_chunk)70 virtual bool NextChunk(Blob *out_chunk) { 71 if (tmp_chunk_ == NULL) { 72 if (!iter_.Next(&tmp_chunk_)) return false; 73 } 74 while (!base_->ExtractNextChunk(out_chunk, tmp_chunk_)) { 75 iter_.Recycle(&tmp_chunk_); 76 if (!iter_.Next(&tmp_chunk_)) return false; 77 } 78 return true; 79 } 80 GetTotalSize(void)81 virtual size_t GetTotalSize(void) { 82 return base_->GetTotalSize(); 83 } 84 ResetPartition(unsigned part_index,unsigned num_parts)85 virtual void ResetPartition(unsigned part_index, unsigned num_parts) { 86 base_->ResetPartition(part_index, num_parts); 87 this->BeforeFirst(); 88 } 89 90 private: 91 /*! \brief internal buffer size */ 92 size_t buffer_size_; 93 /*! \brief batch size */ 94 size_t batch_size_; 95 /*! \brief the place where we get the data */ 96 InputSplitBase *base_; 97 /*! \brief backend thread iterator */ 98 ThreadedIter<InputSplitBase::Chunk> iter_; 99 /*! \brief current chunk of data */ 100 InputSplitBase::Chunk *tmp_chunk_; 101 }; 102 } // namespace io 103 } // namespace dmlc 104 #endif // DMLC_USE_CXX11 105 #endif // DMLC_IO_THREADED_INPUT_SPLIT_H_ 106