1 /*!
2  *  Copyright (c) 2015 by Contributors
3  * \file threaded_input_split.h
4  * \brief a threaded version of InputSplit with a prefetch thread
5  * \author Tianqi Chen
6  */
7 #ifndef DMLC_IO_THREADED_INPUT_SPLIT_H_
8 #define DMLC_IO_THREADED_INPUT_SPLIT_H_
9 
10 #include <dmlc/base.h>
// this code depends on C++11 std::thread support (DMLC_ENABLE_STD_THREAD)
12 #if DMLC_ENABLE_STD_THREAD
#include <dmlc/threadediter.h>
#include <algorithm>
#include <atomic>
#include "./input_split_base.h"
16 
17 namespace dmlc {
18 namespace io {
19 /*!
20  * \brief a threaded version of InputSplit
21  *  wraps an InputSplitBase to use an thread to prefetch the data
22  */
23 class ThreadedInputSplit : public InputSplit {
24  public:
25   /*!
26    * \brief constructor
27    * \param base an base object to define how to read data
28    */
ThreadedInputSplit(InputSplitBase * base,const size_t batch_size)29   explicit ThreadedInputSplit(InputSplitBase *base, const size_t batch_size)
30       : buffer_size_(InputSplitBase::kBufferSize),
31         batch_size_(batch_size),
32         base_(base), tmp_chunk_(NULL) {
33     iter_.set_max_capacity(2);
34     // initalize the iterator
35     iter_.Init([this](InputSplitBase::Chunk **dptr) {
36         if (*dptr == NULL) {
37           *dptr = new InputSplitBase::Chunk(buffer_size_);
38         }
39         return base_->NextBatchEx(*dptr, batch_size_);
40       },
41       [base]() { base->BeforeFirst(); });
42   }
43   // destructor
~ThreadedInputSplit(void)44   virtual ~ThreadedInputSplit(void) {
45     iter_.Destroy();
46     delete tmp_chunk_;
47     delete base_;
48   }
BeforeFirst()49   virtual void BeforeFirst() {
50     iter_.BeforeFirst();
51     if (tmp_chunk_ != NULL) {
52       iter_.Recycle(&tmp_chunk_);
53     }
54   }
HintChunkSize(size_t chunk_size)55   virtual void HintChunkSize(size_t chunk_size) {
56     buffer_size_ = std::max(chunk_size / sizeof(uint32_t), buffer_size_);
57   }
58   // implement next record
NextRecord(Blob * out_rec)59   virtual bool NextRecord(Blob *out_rec) {
60     if (tmp_chunk_ == NULL) {
61       if (!iter_.Next(&tmp_chunk_)) return false;
62     }
63     while (!base_->ExtractNextRecord(out_rec, tmp_chunk_)) {
64       iter_.Recycle(&tmp_chunk_);
65       if (!iter_.Next(&tmp_chunk_)) return false;
66     }
67     return true;
68   }
69   // implement next chunk
NextChunk(Blob * out_chunk)70   virtual bool NextChunk(Blob *out_chunk) {
71     if (tmp_chunk_ == NULL) {
72       if (!iter_.Next(&tmp_chunk_)) return false;
73     }
74     while (!base_->ExtractNextChunk(out_chunk, tmp_chunk_)) {
75       iter_.Recycle(&tmp_chunk_);
76       if (!iter_.Next(&tmp_chunk_)) return false;
77     }
78     return true;
79   }
80 
GetTotalSize(void)81   virtual size_t GetTotalSize(void) {
82     return base_->GetTotalSize();
83   }
84 
ResetPartition(unsigned part_index,unsigned num_parts)85   virtual void ResetPartition(unsigned part_index, unsigned num_parts) {
86     base_->ResetPartition(part_index, num_parts);
87     this->BeforeFirst();
88   }
89 
90  private:
91   /*! \brief internal buffer size */
92   size_t buffer_size_;
93   /*! \brief batch size */
94   size_t batch_size_;
95   /*! \brief the place where we get the data */
96   InputSplitBase *base_;
97   /*! \brief backend thread iterator */
98   ThreadedIter<InputSplitBase::Chunk> iter_;
99   /*! \brief current chunk of data */
100   InputSplitBase::Chunk *tmp_chunk_;
101 };
102 }  // namespace io
103 }  // namespace dmlc
#endif  // DMLC_ENABLE_STD_THREAD
105 #endif  // DMLC_IO_THREADED_INPUT_SPLIT_H_
106