1 // Copyright by Contributors
#include <dmlc/recordio.h>
#include <dmlc/logging.h>
#include <algorithm>
#include <cstring>
#include "./recordio_split.h"
6
7 namespace dmlc {
8 namespace io {
SeekRecordBegin(Stream * fi)9 size_t RecordIOSplitter::SeekRecordBegin(Stream *fi) {
10 size_t nstep = 0;
11 uint32_t v, lrec;
12 while (true) {
13 if (fi->Read(&v, sizeof(v)) == 0) return nstep;
14 nstep += sizeof(v);
15 if (v == RecordIOWriter::kMagic) {
16 CHECK(fi->Read(&lrec, sizeof(lrec)) != 0)
17 << "invalid record io format";
18 nstep += sizeof(lrec);
19 uint32_t cflag = RecordIOWriter::DecodeFlag(lrec);
20 if (cflag == 0 || cflag == 1) break;
21 }
22 }
23 // should point at head of record
24 return nstep - 2 * sizeof(uint32_t);
25 }
FindLastRecordBegin(const char * begin,const char * end)26 const char* RecordIOSplitter::FindLastRecordBegin(const char *begin,
27 const char *end) {
28 CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U);
29 CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U);
30 const uint32_t *pbegin = reinterpret_cast<const uint32_t *>(begin);
31 const uint32_t *p = reinterpret_cast<const uint32_t *>(end);
32 CHECK(p >= pbegin + 2);
33 for (p = p - 2; p != pbegin; --p) {
34 if (p[0] == RecordIOWriter::kMagic) {
35 uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
36 if (cflag == 0 || cflag == 1) {
37 return reinterpret_cast<const char*>(p);
38 }
39 }
40 }
41 return begin;
42 }
43
ExtractNextRecord(Blob * out_rec,Chunk * chunk)44 bool RecordIOSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) {
45 if (chunk->begin == chunk->end) return false;
46 CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end)
47 << "Invalid RecordIO Format";
48 CHECK_EQ((reinterpret_cast<size_t>(chunk->begin) & 3UL), 0U);
49 CHECK_EQ((reinterpret_cast<size_t>(chunk->end) & 3UL), 0U);
50 uint32_t *p = reinterpret_cast<uint32_t *>(chunk->begin);
51 uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
52 uint32_t clen = RecordIOWriter::DecodeLength(p[1]);
53 // skip header
54 out_rec->dptr = chunk->begin + 2 * sizeof(uint32_t);
55 // move pbegin
56 chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
57 CHECK(chunk->begin <= chunk->end) << "Invalid RecordIO Format";
58 out_rec->size = clen;
59 if (cflag == 0) return true;
60 const uint32_t kMagic = RecordIOWriter::kMagic;
61 // abnormal path, move data around to make a full part
62 CHECK(cflag == 1U) << "Invalid RecordIO Format";
63 while (cflag != 3U) {
64 CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end);
65 p = reinterpret_cast<uint32_t *>(chunk->begin);
66 CHECK(p[0] == RecordIOWriter::kMagic);
67 cflag = RecordIOWriter::DecodeFlag(p[1]);
68 clen = RecordIOWriter::DecodeLength(p[1]);
69 // pad kmagic in between
70 std::memcpy(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
71 &kMagic, sizeof(kMagic));
72 out_rec->size += sizeof(kMagic);
73 // move the rest of the blobs
74 if (clen != 0) {
75 std::memmove(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
76 chunk->begin + 2 * sizeof(uint32_t), clen);
77 out_rec->size += clen;
78 }
79 chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
80 }
81 return true;
82 }
83 } // namespace io
84 } // namespace dmlc
85