1 // Copyright by Contributors
2 #include <dmlc/recordio.h>
3 #include <dmlc/logging.h>
4 #include <dmlc/io.h>
5 #include <algorithm>
6 #include <fstream>
7 #include "./indexed_recordio_split.h"
8
9 namespace dmlc {
10 namespace io {
11
ResetPartition(unsigned rank,unsigned nsplit)12 void IndexedRecordIOSplitter::ResetPartition(unsigned rank, unsigned nsplit) {
13 size_t ntotal = index_.size();
14 size_t ntotalbytes = file_offset_.back();
15 size_t nstep = (ntotal + nsplit - 1) / nsplit;
16 if (rank * nstep >= ntotal) return;
17 index_begin_ = rank * nstep;
18 offset_begin_ = index_[index_begin_].first;
19 if ((rank + 1) * nstep < ntotal) {
20 index_end_ = (rank + 1) * nstep;
21 offset_end_ = index_[index_end_].first;
22 } else {
23 offset_end_ = ntotalbytes;
24 index_end_ = index_.size();
25 index_.push_back(std::make_pair(offset_end_, 0));
26 }
27 offset_curr_ = offset_begin_;
28 file_ptr_ = std::upper_bound(file_offset_.begin(),
29 file_offset_.end(),
30 offset_begin_) - file_offset_.begin() - 1;
31 file_ptr_end_ = std::upper_bound(file_offset_.begin(),
32 file_offset_.end(),
33 offset_end_) - file_offset_.begin() - 1;
34 if (fs_ != NULL) {
35 delete fs_; fs_ = NULL;
36 }
37 fs_ = filesys_->OpenForRead(files_[file_ptr_].path);
38 current_index_ = index_begin_;
39 n_overflow_ = 0;
40 this->BeforeFirst();
41 }
42
ReadIndexFile(FileSystem * fs,const std::string & index_uri)43 void IndexedRecordIOSplitter::ReadIndexFile(FileSystem *fs, const std::string& index_uri) {
44 std::vector<URI> expanded_list = this->ConvertToURIs(index_uri);
45 CHECK_EQ(expanded_list.size(), 1ul)
46 << "IndexedRecordIOSplitter does not support multiple index files";
47 for (size_t i = 0; i < expanded_list.size(); ++i) {
48 const URI& path = expanded_list[i];
49 std::unique_ptr<dmlc::Stream> file_stream(fs->Open(path, "r", true));
50 dmlc::istream index_file(file_stream.get());
51 std::vector<size_t> temp;
52 size_t index, offset;
53 while (index_file >> index >> offset) {
54 temp.push_back(offset);
55 }
56 std::sort(temp.begin(), temp.end());
57 for (size_t j = 0; j < temp.size() - 1; ++j) {
58 index_.push_back(std::make_pair(temp[j], temp[j + 1] - temp[j]));
59 }
60 index_.push_back(std::make_pair(temp.back(), file_offset_.back() - temp.back()));
61 }
62 }
63
64 // Inefficient, but not used anywhere and optimization
65 // would require change of the API, so I leave it as is
SeekRecordBegin(Stream * fi)66 size_t IndexedRecordIOSplitter::SeekRecordBegin(Stream *fi) {
67 size_t nstep = 0;
68 uint32_t v, lrec;
69 while (true) {
70 if (fi->Read(&v, sizeof(v)) == 0) return nstep;
71 nstep += sizeof(v);
72 if (v == RecordIOWriter::kMagic) {
73 CHECK(fi->Read(&lrec, sizeof(lrec)) != 0)
74 << "invalid record io format";
75 nstep += sizeof(lrec);
76 uint32_t cflag = RecordIOWriter::DecodeFlag(lrec);
77 if (cflag == 0 || cflag == 1) break;
78 }
79 }
80 // should point at head of record
81 return nstep - 2 * sizeof(uint32_t);
82 }
83
84 // Inefficient, but not used anywhere and optimization
85 // would require change of the API, so I leave it as is
FindLastRecordBegin(const char * begin,const char * end)86 const char* IndexedRecordIOSplitter::FindLastRecordBegin(const char *begin,
87 const char *end) {
88 CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U);
89 CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U);
90 const uint32_t *pbegin = reinterpret_cast<const uint32_t *>(begin);
91 const uint32_t *p = reinterpret_cast<const uint32_t *>(end);
92 CHECK(p >= pbegin + 2);
93 for (p = p - 2; p != pbegin; --p) {
94 if (p[0] == RecordIOWriter::kMagic) {
95 uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
96 if (cflag == 0 || cflag == 1) {
97 return reinterpret_cast<const char*>(p);
98 }
99 }
100 }
101 return begin;
102 }
103
ExtractNextRecord(Blob * out_rec,Chunk * chunk)104 bool IndexedRecordIOSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) {
105 if (chunk->begin == chunk->end) return false;
106 CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end)
107 << "Invalid RecordIO Format";
108 CHECK_EQ((reinterpret_cast<size_t>(chunk->begin) & 3UL), 0U);
109 CHECK_EQ((reinterpret_cast<size_t>(chunk->end) & 3UL), 0U);
110 uint32_t *p = reinterpret_cast<uint32_t *>(chunk->begin);
111 uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
112 uint32_t clen = RecordIOWriter::DecodeLength(p[1]);
113 // skip header
114 out_rec->dptr = chunk->begin + 2 * sizeof(uint32_t);
115 // move pbegin
116 chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
117 CHECK(chunk->begin <= chunk->end) << "Invalid RecordIO Format";
118 out_rec->size = clen;
119 if (cflag == 0) return true;
120 const uint32_t kMagic = RecordIOWriter::kMagic;
121 // abnormal path, move data around to make a full part
122 CHECK(cflag == 1U) << "Invalid RecordIO Format";
123 while (cflag != 3U) {
124 CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end);
125 p = reinterpret_cast<uint32_t *>(chunk->begin);
126 CHECK(p[0] == RecordIOWriter::kMagic);
127 cflag = RecordIOWriter::DecodeFlag(p[1]);
128 clen = RecordIOWriter::DecodeLength(p[1]);
129 // pad kmagic in between
130 std::memcpy(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
131 &kMagic, sizeof(kMagic));
132 out_rec->size += sizeof(kMagic);
133 // move the rest of the blobs
134 if (clen != 0) {
135 std::memmove(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
136 chunk->begin + 2 * sizeof(uint32_t), clen);
137 out_rec->size += clen;
138 }
139 chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
140 }
141 return true;
142 }
143
ReadChunk(void * buf,size_t * size)144 bool IndexedRecordIOSplitter::ReadChunk(void *buf, size_t *size) {
145 size_t max_size = *size;
146 size_t nread = this->Read(reinterpret_cast<char*>(buf),
147 max_size);
148 if (nread == 0) return false;
149 if (nread != max_size) {
150 *size = nread;
151 }
152 return true;
153 }
154
NextChunk(Blob * out_chunk)155 bool IndexedRecordIOSplitter::NextChunk(Blob *out_chunk) {
156 return this->NextBatch(out_chunk, batch_size_);
157 }
158
NextBatchEx(Chunk * chunk,size_t n_records)159 bool IndexedRecordIOSplitter::NextBatchEx(Chunk *chunk, size_t n_records) {
160 if (shuffle_) {
161 bool ret = true;
162 size_t n_read = 0;
163 size_t n = n_overflow_ == 0?n_records:n_overflow_;
164 while (n_read < n) {
165 if (current_index_ < permutation_.size()) {
166 offset_curr_ = index_[permutation_[current_index_]].first;
167 buffer_size_ = index_[permutation_[current_index_]].second/sizeof(uint32_t);
168 size_t new_file_ptr = std::upper_bound(file_offset_.begin(),
169 file_offset_.end(),
170 offset_curr_) - file_offset_.begin() - 1;
171 if (new_file_ptr != file_ptr_) {
172 delete fs_;
173 file_ptr_ = new_file_ptr;
174 fs_ = filesys_->OpenForRead(files_[file_ptr_].path);
175 }
176 fs_->Seek(offset_curr_ - file_offset_[file_ptr_]);
177 if (n_read == 0) {
178 ret = ret && chunk->Load(this, buffer_size_);
179 } else {
180 ret = ret && chunk->Append(this, buffer_size_);
181 }
182 if (ret) {
183 ++n_read;
184 ++current_index_;
185 } else {
186 break;
187 }
188 } else {
189 break;
190 }
191 }
192 if (n_read > 0) {
193 n_overflow_ = n - n_read;
194 return true;
195 } else {
196 return false;
197 }
198 } else {
199 size_t last;
200 if (n_overflow_ == 0) {
201 last = std::min(current_index_ + n_records, index_end_);
202 n_overflow_ = current_index_ + n_records - last;
203 } else {
204 last = std::min(current_index_ + n_overflow_, index_end_);
205 n_overflow_ = current_index_ + n_overflow_ - last;
206 }
207 buffer_size_ = (index_[last].first - index_[current_index_].first)/INDEXED_RECORDIO_ALIGN;
208 current_index_ = last;
209 return chunk->Load(this, buffer_size_);
210 }
211 return true;
212 }
213
NextBatch(Blob * out_chunk,size_t batch_size)214 bool IndexedRecordIOSplitter::NextBatch(Blob *out_chunk, size_t batch_size) {
215 while (!ExtractNextChunk(out_chunk, &tmp_chunk_)) {
216 if (!NextBatchEx(&tmp_chunk_, batch_size)) return false;
217 }
218 return true;
219 }
220
BeforeFirst(void)221 void IndexedRecordIOSplitter::BeforeFirst(void) {
222 if (shuffle_) {
223 permutation_.clear();
224 for (size_t i = index_begin_; i < index_end_; ++i) {
225 permutation_.push_back(i);
226 }
227 std::shuffle(permutation_.begin(), permutation_.end(), rnd_);
228 current_index_ = 0;
229 } else {
230 current_index_ = index_begin_;
231 }
232 InputSplitBase::BeforeFirst();
233 }
234 } // namespace io
235 } // namespace dmlc
236