// Copyright by Contributors
#include <dmlc/recordio.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <algorithm>
#include <cstring>
#include <fstream>
#include <memory>
#include <utility>
#include <vector>
#include "./indexed_recordio_split.h"

namespace dmlc {
namespace io {

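// Assign worker `rank` its share of the records when the data is split nsplit ways:
// compute the record index range [index_begin_, index_end_) and the matching byte
// offsets, then open the file that contains the first record of this partition.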
void IndexedRecordIOSplitter::ResetPartition(unsigned rank, unsigned nsplit) {
  size_t ntotal = index_.size();
  size_t ntotalbytes = file_offset_.back();
  size_t nstep = (ntotal + nsplit - 1) / nsplit;
  if (rank * nstep >= ntotal) return;
  index_begin_ = rank * nstep;
  offset_begin_ = index_[index_begin_].first;
  if ((rank + 1) * nstep < ntotal) {
    index_end_ = (rank + 1) * nstep;
    offset_end_ = index_[index_end_].first;
  } else {
    offset_end_ = ntotalbytes;
    index_end_ = index_.size();
    index_.push_back(std::make_pair(offset_end_, 0));
  }
  offset_curr_ = offset_begin_;
  file_ptr_ = std::upper_bound(file_offset_.begin(),
                               file_offset_.end(),
                               offset_begin_) - file_offset_.begin() - 1;
  file_ptr_end_ = std::upper_bound(file_offset_.begin(),
                                   file_offset_.end(),
                                   offset_end_) - file_offset_.begin() - 1;
  if (fs_ != NULL) {
    delete fs_; fs_ = NULL;
  }
  fs_ = filesys_->OpenForRead(files_[file_ptr_].path);
  current_index_ = index_begin_;
  n_overflow_ = 0;
  this->BeforeFirst();
}

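// Read the plain-text index file, which lists a record key and a byte offset per line,
// sort the offsets, and store an (offset, length) pair for every record in index_.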
void IndexedRecordIOSplitter::ReadIndexFile(FileSystem *fs, const std::string& index_uri) {
  std::vector<URI> expanded_list = this->ConvertToURIs(index_uri);
  CHECK_EQ(expanded_list.size(), 1ul)
    << "IndexedRecordIOSplitter does not support multiple index files";
  for (size_t i = 0; i < expanded_list.size(); ++i) {
    const URI& path = expanded_list[i];
    std::unique_ptr<dmlc::Stream> file_stream(fs->Open(path, "r", true));
    dmlc::istream index_file(file_stream.get());
    std::vector<size_t> temp;
    size_t index, offset;
    while (index_file >> index >> offset) {
      temp.push_back(offset);
    }
    CHECK(!temp.empty()) << "Index file " << index_uri << " is empty";
    std::sort(temp.begin(), temp.end());
    for (size_t j = 0; j < temp.size() - 1; ++j) {
      index_.push_back(std::make_pair(temp[j], temp[j + 1] - temp[j]));
    }
    index_.push_back(std::make_pair(temp.back(), file_offset_.back() - temp.back()));
  }
}

// Inefficient, but not used anywhere; optimizing it would require
// a change of the API, so it is left as is.
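// Scans forward from the current position of the stream until it finds the header of
// a record whose flag is 0 (full record) or 1 (first part of a split record) and
// returns the number of bytes between the starting position and that header
// (or the number of bytes read if the stream ends first).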
size_t IndexedRecordIOSplitter::SeekRecordBegin(Stream *fi) {
  size_t nstep = 0;
  uint32_t v, lrec;
  while (true) {
    if (fi->Read(&v, sizeof(v)) == 0) return nstep;
    nstep += sizeof(v);
    if (v == RecordIOWriter::kMagic) {
      CHECK(fi->Read(&lrec, sizeof(lrec)) != 0)
            << "invalid record io format";
      nstep += sizeof(lrec);
      uint32_t cflag = RecordIOWriter::DecodeFlag(lrec);
      if (cflag == 0 || cflag == 1) break;
    }
  }
  // should point at head of record
  return nstep - 2 * sizeof(uint32_t);
}

// Inefficient, but not used anywhere; optimizing it would require
// a change of the API, so it is left as is.
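// Scans backwards through [begin, end) and returns a pointer to the last record header
// whose flag is 0 or 1, or begin if no such header is found.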
const char* IndexedRecordIOSplitter::FindLastRecordBegin(const char *begin,
                                                          const char *end) {
  CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U);
  CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U);
  const uint32_t *pbegin = reinterpret_cast<const uint32_t *>(begin);
  const uint32_t *p = reinterpret_cast<const uint32_t *>(end);
  CHECK(p >= pbegin + 2);
  for (p = p - 2; p != pbegin; --p) {
    if (p[0] == RecordIOWriter::kMagic) {
      uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
      if (cflag == 0 || cflag == 1) {
        return reinterpret_cast<const char*>(p);
      }
    }
  }
  return begin;
}

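// Extract one logical record from the chunk into out_rec. A record written as several
// parts (flags 1/2/3) is stitched back together in place, with the magic word inserted
// between consecutive parts.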
bool IndexedRecordIOSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) {
  if (chunk->begin == chunk->end) return false;
  CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end)
      << "Invalid RecordIO Format";
  CHECK_EQ((reinterpret_cast<size_t>(chunk->begin) & 3UL), 0U);
  CHECK_EQ((reinterpret_cast<size_t>(chunk->end) & 3UL), 0U);
  uint32_t *p = reinterpret_cast<uint32_t *>(chunk->begin);
  uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
  uint32_t clen = RecordIOWriter::DecodeLength(p[1]);
  // skip header
  out_rec->dptr = chunk->begin + 2 * sizeof(uint32_t);
  // move pbegin
  chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
  CHECK(chunk->begin <= chunk->end) << "Invalid RecordIO Format";
  out_rec->size = clen;
  if (cflag == 0) return true;
  const uint32_t kMagic = RecordIOWriter::kMagic;
  // abnormal path, move data around to make a full part
  CHECK(cflag == 1U) << "Invalid RecordIO Format";
  while (cflag != 3U) {
    CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end);
    p = reinterpret_cast<uint32_t *>(chunk->begin);
    CHECK(p[0] == RecordIOWriter::kMagic);
    cflag = RecordIOWriter::DecodeFlag(p[1]);
    clen = RecordIOWriter::DecodeLength(p[1]);
    // pad kmagic in between
    std::memcpy(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
                &kMagic, sizeof(kMagic));
    out_rec->size += sizeof(kMagic);
    // move the rest of the blobs
    if (clen != 0) {
      std::memmove(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
                   chunk->begin + 2 * sizeof(uint32_t), clen);
      out_rec->size += clen;
    }
    chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
  }
  return true;
}

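// Read up to *size bytes into buf and shrink *size to the number of bytes actually
// read; returns false if nothing could be read.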
bool IndexedRecordIOSplitter::ReadChunk(void *buf, size_t *size) {
  size_t max_size = *size;
  size_t nread = this->Read(reinterpret_cast<char*>(buf),
                            max_size);
  if (nread == 0) return false;
  if (nread != max_size) {
    *size = nread;
  }
  return true;
}

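// A chunk is the next batch of batch_size_ records.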
bool IndexedRecordIOSplitter::NextChunk(Blob *out_chunk) {
  return this->NextBatch(out_chunk, batch_size_);
}

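// Load the raw bytes of up to n_records records into chunk. With shuffling enabled,
// every permuted record is located and read individually, reopening the underlying
// file when the record lives in a different file; otherwise a single contiguous byte
// range is read. Requested records that could not be read yet are tracked in n_overflow_.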
bool IndexedRecordIOSplitter::NextBatchEx(Chunk *chunk, size_t n_records) {
  if (shuffle_) {
    bool ret = true;
    size_t n_read = 0;
    size_t n = n_overflow_ == 0 ? n_records : n_overflow_;
    while (n_read < n) {
      if (current_index_ < permutation_.size()) {
        offset_curr_ = index_[permutation_[current_index_]].first;
        buffer_size_ = index_[permutation_[current_index_]].second / sizeof(uint32_t);
        size_t new_file_ptr = std::upper_bound(file_offset_.begin(),
                                               file_offset_.end(),
                                               offset_curr_) - file_offset_.begin() - 1;
        if (new_file_ptr != file_ptr_) {
          delete fs_;
          file_ptr_ = new_file_ptr;
          fs_ = filesys_->OpenForRead(files_[file_ptr_].path);
        }
        fs_->Seek(offset_curr_ - file_offset_[file_ptr_]);
        if (n_read == 0) {
          ret = ret && chunk->Load(this, buffer_size_);
        } else {
          ret = ret && chunk->Append(this, buffer_size_);
        }
        if (ret) {
          ++n_read;
          ++current_index_;
        } else {
          break;
        }
      } else {
        break;
      }
    }
    if (n_read > 0) {
      n_overflow_ = n - n_read;
      return true;
    } else {
      return false;
    }
  } else {
    size_t last;
    if (n_overflow_ == 0) {
      last = std::min(current_index_ + n_records, index_end_);
      n_overflow_ = current_index_ + n_records - last;
    } else {
      last = std::min(current_index_ + n_overflow_, index_end_);
      n_overflow_ = current_index_ + n_overflow_ - last;
    }
    buffer_size_ = (index_[last].first - index_[current_index_].first) / INDEXED_RECORDIO_ALIGN;
    current_index_ = last;
    return chunk->Load(this, buffer_size_);
  }
  return true;
}

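// Return the next batch of records as a blob, refilling tmp_chunk_ through NextBatchEx
// whenever no further record can be extracted from it.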
bool IndexedRecordIOSplitter::NextBatch(Blob *out_chunk, size_t batch_size) {
  while (!ExtractNextChunk(out_chunk, &tmp_chunk_)) {
    if (!NextBatchEx(&tmp_chunk_, batch_size)) return false;
  }
  return true;
}

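// Rewind to the beginning of this partition; with shuffling enabled, regenerate the
// random permutation of record indices first.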
void IndexedRecordIOSplitter::BeforeFirst(void) {
  if (shuffle_) {
    permutation_.clear();
    for (size_t i = index_begin_; i < index_end_; ++i) {
      permutation_.push_back(i);
    }
    std::shuffle(permutation_.begin(), permutation_.end(), rnd_);
    current_index_ = 0;
  } else {
    current_index_ = index_begin_;
  }
  InputSplitBase::BeforeFirst();
}
}  // namespace io
}  // namespace dmlc